Skip to content

Commit b77527a

Browse files
denghuiludyzheng
authored and committed
refactor: add code comments for all multi-device ops (#1706)
1 parent 6d5709c commit b77527a

File tree

12 files changed

+465
-5
lines changed

12 files changed

+465
-5
lines changed

source/module_base/include/math_multi_device.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,23 @@ namespace ModuleBase {
88

99
template <typename FPTYPE, typename Device>
1010
struct cal_ylm_real_op {
11+
/// @brief YLM_REAL::Real spherical harmonics ylm(G) up to l=lmax
12+
/// Uses the numerical recursive algorithm given in Numerical Recipes
13+
///
14+
/// Input Parameters
15+
/// @param ctx - which device this function runs on
16+
/// @param ng - problem size
17+
/// @param lmax - determined by lmax2
18+
/// @param SQRT2 - ModuleBase::SQRT2
19+
/// @param PI - ModuleBase::PI
20+
/// @param PI_HALF - ModuleBase::PI_HALF
21+
/// @param FOUR_PI - ModuleBase::FOUR_PI
22+
/// @param SQRT_INVERSE_FOUR_PI - ModuleBase::SQRT_INVERSE_FOUR_PI
23+
/// @param g - input array with size npw * 3, GlobalC::wf.get_1qvec_cartesian
24+
/// @param p - intermediate array
25+
///
26+
/// Output Parameters
27+
/// @param ylm - output array
1128
void operator() (
1229
const Device *ctx,
1330
const int &ng,

source/module_elecstate/include/elecstate_multi_device.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,38 @@ namespace elecstate{
99

1010
template <typename FPTYPE, typename Device>
1111
struct elecstate_pw_op {
12+
/// @brief Calculate psiToRho output within the band-by-band loop, NSPIN != 4
13+
///
14+
/// Input Parameters
15+
/// @param ctx - which device this function runs on
16+
/// @param spin - current spin
17+
/// @param nrxx - number of planewaves
18+
/// @param weight - input constant
19+
/// @param wfcr - input array, psi in real space
20+
///
21+
/// Output Parameters
22+
/// @param rho - electronic densities
1223
void operator() (
1324
const Device* ctx,
1425
const int& spin,
1526
const int& nrxx,
1627
const FPTYPE& weight,
1728
FPTYPE** rho,
1829
const std::complex<FPTYPE>* wfcr);
19-
30+
31+
/// @brief Calculate psiToRho output within the band-by-band loop, NSPIN == 4
32+
///
33+
/// Input Parameters
34+
/// @param ctx - which device this function runs on
35+
/// @param DOMAG - GlobalV::DOMAG
36+
/// @param DOMAG_Z - GlobalV::DOMAG_Z
37+
/// @param nrxx - number of planewaves
38+
/// @param weight - input constant
39+
/// @param wfcr - input array, psi in real space
40+
/// @param wfcr_another_spin - input array, psi in real space
41+
///
42+
/// Output Parameters
43+
/// @param rho - electronic densities
2044
void operator() (
2145
const Device* ctx,
2246
const bool& DOMAG,

source/module_hamilt/include/ekinetic.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,20 @@
77
namespace hamilt {
88
template <typename FPTYPE, typename Device>
99
struct ekinetic_pw_op {
10+
/// @brief Compute the ekinetic potential of hPsi
11+
///
12+
/// Input Parameters
13+
/// \param dev : the type of computing device
14+
/// \param nband : nbands
15+
/// \param npw : number of planewaves of current k point
16+
/// \param max_npw : max number of planewaves of all k points
17+
/// \param tpiba2 : GlobalC::ucell.tpiba2
18+
/// \param spin : current spin
19+
/// \param gk2_ik : GlobalC::wfcpw->gk2
20+
/// \param tmpsi_in : intermediate array
21+
///
22+
/// Output Parameters
23+
/// \param tmhpsi : output array
1024
void operator() (
1125
const Device* dev,
1226
const int& nband,

source/module_hamilt/include/nonlocal.h

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@
77
namespace hamilt {
88
template <typename FPTYPE, typename Device>
99
struct nonlocal_pw_op {
10+
/// @brief Compute the nonlocal potential of hPsi
11+
///
12+
/// Input Parameters
13+
/// \param dev : the type of computing device
14+
/// \param l1 : ucell->atoms[it].na
15+
/// \param l2 : nbands
16+
/// \param l3 : ucell->atoms[it].ncpp.nh
17+
/// \param sum : intermediate value
18+
/// \param iat : intermediate value
19+
/// \param spin : current spin
20+
/// \param nkb : ppcell->nkb, number of kpoints
21+
/// \param deeq_x : second dimension of deeq
22+
/// \param deeq_y : third dimension of deeq
23+
/// \param deeq_z : fourth dimension of deeq
24+
/// \param deeq : ppcell->deeq
25+
/// \param becp : intermediate array
26+
///
27+
/// Output Parameters
28+
/// \param ps : output array
1029
void operator() (
1130
const Device* dev,
1231
const int& l1,
@@ -22,7 +41,25 @@ struct nonlocal_pw_op {
2241
const FPTYPE* deeq,
2342
std::complex<FPTYPE>* ps,
2443
const std::complex<FPTYPE>* becp);
25-
44+
45+
/// @brief Compute the nonlocal potential of hPsi, with NSPIN > 2
46+
///
47+
/// Input Parameters
48+
/// \param dev : the type of computing device
49+
/// \param l1 : ucell->atoms[it].na
50+
/// \param l2 : nbands
51+
/// \param l3 : ucell->atoms[it].ncpp.nh
52+
/// \param sum : intermediate value
53+
/// \param iat : intermediate value
54+
/// \param nkb : ppcell->nkb, number of kpoints
55+
/// \param deeq_x : second dimension of deeq
56+
/// \param deeq_y : third dimension of deeq
57+
/// \param deeq_z : fourth dimension of deeq
58+
/// \param deeq_nc : ppcell->deeq_nc
59+
/// \param becp : intermediate array
60+
///
61+
/// Output Parameters
62+
/// \param ps : output array
2663
void operator() (
2764
const Device* dev,
2865
const int& l1,

source/module_hamilt/include/veff.h

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,42 @@
77
namespace hamilt {
88
template <typename FPTYPE, typename Device>
99
struct veff_pw_op {
10+
/// @brief Compute the effective potential of hPsi in real space,
11+
/// out[ir] *= in[ir];
12+
///
13+
/// Input Parameters
14+
/// \param dev : the type of computing device
15+
/// \param size : array size
16+
/// \param in : input array, elecstate::Potential::v_effective
17+
///
18+
/// Output Parameters
19+
/// \param out : output array
1020
void operator() (
1121
const Device* dev,
1222
const int& size,
1323
std::complex<FPTYPE>* out,
1424
const FPTYPE* in);
15-
25+
26+
/// @brief Compute the effective potential of hPsi in real space with NSPIN > 2,
27+
///
28+
/// out[ir] = out[ir] * (in[0][ir] + in[3][ir])
29+
/// + out1[ir]
30+
/// * (in[1][ir]
31+
/// - std::complex<FPTYPE>(0.0, 1.0) * in[2][ir]);
32+
///
33+
/// out1[ir] = out1[ir] * (in[0][ir] - in[3][ir])
34+
/// + out[ir]
35+
/// * (in[1][ir]
36+
/// + std::complex<FPTYPE>(0.0, 1.0) * in[2][ir]);
37+
///
38+
/// Input Parameters
39+
/// \param dev : the type of computing device
40+
/// \param size : array size
41+
/// \param in : input array, elecstate::Potential::v_effective
42+
///
43+
/// Output Parameters
44+
/// \param out : output array 1
45+
/// \param out1 : output array 2
1646
void operator() (
1747
const Device* dev,
1848
const int& size,

source/module_hsolver/include/math_kernel.h

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,18 @@ namespace hsolver
5757

5858
template <typename FPTYPE, typename Device>
5959
struct zdot_real_op {
60+
/// @brief zdot_real_op computes the dot product of the given complex arrays (treated as float arrays).
61+
/// MPI communication may be performed when the planewave parallelization strategy is enabled.
62+
///
63+
/// Input Parameters
64+
/// \param d : the type of computing device
65+
/// \param dim : array size
66+
/// \param psi_L : input array A
67+
/// \param psi_R : input array B
68+
/// \param reduce : flag to control whether to perform the MPI communications
69+
///
70+
/// \return
71+
/// FPTYPE : dot product result
6072
FPTYPE operator() (
6173
const Device* d,
6274
const int& dim,
@@ -68,6 +80,16 @@ struct zdot_real_op {
6880
// vector operator: result[i] = vector[i] / constant
6981
template <typename FPTYPE, typename Device> struct vector_div_constant_op
7082
{
83+
/// @brief result[i] = vector[i] / constant
84+
///
85+
/// Input Parameters
86+
/// \param d : the type of computing device
87+
/// \param dim : array size
88+
/// \param vector : input array
89+
/// \param constant : input constant
90+
///
91+
/// Output Parameters
92+
/// \param result : output array
7193
void operator()(const Device* d,
7294
const int dim,
7395
std::complex<FPTYPE>* result,
@@ -78,6 +100,17 @@ template <typename FPTYPE, typename Device> struct vector_div_constant_op
78100
// replace vector_div_constant_op : x = alpha * x
79101
template <typename FPTYPE, typename Device> struct scal_op
80102
{
103+
/// @brief x = alpha * x
104+
///
105+
/// Input Parameters
106+
/// \param d : the type of computing device
107+
/// \param N : array size
108+
/// \param alpha : input constant
109+
/// \param X : input array
110+
/// \param incx : stride between consecutive elements of array X
111+
///
112+
/// Output Parameters
113+
/// \param X : output array
81114
void operator()(const Device* d,
82115
const int& N,
83116
const std::complex<FPTYPE>* alpha,
@@ -88,6 +121,16 @@ template <typename FPTYPE, typename Device> struct scal_op
88121
// vector operator: result[i] = vector1[i](complex) * vector2[i](not complex)
89122
template <typename FPTYPE, typename Device> struct vector_mul_vector_op
90123
{
124+
/// @brief result[i] = vector1[i](complex) * vector2[i](not complex)
125+
///
126+
/// Input Parameters
127+
/// \param d : the type of computing device
128+
/// \param dim : array size
129+
/// \param vector1 : input array A
130+
/// \param vector2 : input array B
131+
///
132+
/// Output Parameters
133+
/// \param result : output array
91134
void operator()(const Device* d,
92135
const int& dim,
93136
std::complex<FPTYPE>* result,
@@ -98,6 +141,16 @@ template <typename FPTYPE, typename Device> struct vector_mul_vector_op
98141
// vector operator: result[i] = vector1[i](complex) / vector2[i](not complex)
99142
template <typename FPTYPE, typename Device> struct vector_div_vector_op
100143
{
144+
/// @brief result[i] = vector1[i](complex) / vector2[i](not complex)
145+
///
146+
/// Input Parameters
147+
/// \param d : the type of computing device
148+
/// \param dim : array size
149+
/// \param vector1 : input array A
150+
/// \param vector2 : input array B
151+
///
152+
/// Output Parameters
153+
/// \param result : output array
101154
void operator()(const Device* d,
102155
const int& dim,
103156
std::complex<FPTYPE>* result,
@@ -108,6 +161,18 @@ template <typename FPTYPE, typename Device> struct vector_div_vector_op
108161
// vector operator: result[i] = vector1[i] * constant1 + vector2[i] * constant2
109162
template <typename FPTYPE, typename Device> struct constantvector_addORsub_constantVector_op
110163
{
164+
/// @brief result[i] = vector1[i] * constant1 + vector2[i] * constant2
165+
///
166+
/// Input Parameters
167+
/// \param d : the type of computing device
168+
/// \param dim : array size
169+
/// \param vector1 : input array A
170+
/// \param constant1 : input constant a
171+
/// \param vector2 : input array B
172+
/// \param constant2 : input constant b
173+
///
174+
/// Output Parameters
175+
/// \param result : output array
111176
void operator()(const Device* d,
112177
const int& dim,
113178
std::complex<FPTYPE>* result,
@@ -120,6 +185,19 @@ template <typename FPTYPE, typename Device> struct constantvector_addORsub_const
120185
// compute Y = alpha * X + Y
121186
template <typename FPTYPE, typename Device> struct axpy_op
122187
{
188+
/// @brief Y = alpha * X + Y
189+
///
190+
/// Input Parameters
191+
/// \param d : the type of computing device
192+
/// \param N : array size
193+
/// \param alpha : input constant alpha
194+
/// \param X : input array X
195+
/// \param incX : stride between consecutive elements of X
196+
/// \param Y : input array Y
197+
/// \param incY : stride between consecutive elements of Y
198+
///
199+
/// Output Parameters
200+
/// \param Y : output array Y
123201
void operator()(const Device* d,
124202
const int& N,
125203
const std::complex<FPTYPE>* alpha,
@@ -132,6 +210,24 @@ template <typename FPTYPE, typename Device> struct axpy_op
132210
// compute y = alpha * op(A) * x + beta * y
133211
template <typename FPTYPE, typename Device> struct gemv_op
134212
{
213+
/// @brief y = alpha * op(A) * x + beta * y
214+
///
215+
/// Input Parameters
216+
/// \param d : the type of computing device
217+
/// \param trans : whether to transpose A
218+
/// \param m : first dimension of matrix
219+
/// \param n : second dimension of matrix
220+
/// \param alpha : input constant alpha
221+
/// \param A : input matrix A
222+
/// \param lda : leading dimension of A
223+
/// \param X : input array X
224+
/// \param incx : stride between consecutive elements of X
225+
/// \param beta : input constant beta
226+
/// \param Y : input array Y
227+
/// \param incy : stride between consecutive elements of Y
228+
///
229+
/// Output Parameters
230+
/// \param Y : output array Y
135231
void operator()(const Device* d,
136232
const char& trans,
137233
const int& m,
@@ -150,6 +246,26 @@ template <typename FPTYPE, typename Device> struct gemv_op
150246
// compute C = alpha * op(A) * op(B) + beta * C
151247
template <typename FPTYPE, typename Device> struct gemm_op
152248
{
249+
/// @brief C = alpha * op(A) * op(B) + beta * C
250+
///
251+
/// Input Parameters
252+
/// \param d : the type of computing device
253+
/// \param transa : whether to transpose matrix A
254+
/// \param transb : whether to transpose matrix B
255+
/// \param m : first dimension of matrix multiplication
256+
/// \param n : second dimension of matrix multiplication
257+
/// \param k : third dimension of matrix multiplication
258+
/// \param alpha : input constant alpha
259+
/// \param a : input matrix A
260+
/// \param lda : leading dimension of A
261+
/// \param b : input matrix B
262+
/// \param ldb : leading dimension of B
263+
/// \param beta : input constant beta
264+
/// \param c : input matrix C
265+
/// \param ldc : leading dimension of C
266+
///
267+
/// Output Parameters
268+
/// \param c : output matrix C
153269
void operator()(const Device* d,
154270
const char& transa,
155271
const char& transb,
@@ -168,6 +284,16 @@ template <typename FPTYPE, typename Device> struct gemm_op
168284

169285
template <typename FPTYPE, typename Device> struct matrixTranspose_op
170286
{
287+
/// @brief transpose the input matrix
288+
///
289+
/// Input Parameters
290+
/// \param d : the type of computing device
291+
/// \param row : first dimension of matrix
292+
/// \param col : second dimension of matrix
293+
/// \param input_matrix : input matrix
294+
///
295+
/// Output Parameters
296+
/// \param output_matrix : output matrix
171297
void operator()(const Device* d,
172298
const int& row,
173299
const int& col,
@@ -177,6 +303,17 @@ template <typename FPTYPE, typename Device> struct matrixTranspose_op
177303

178304
template <typename FPTYPE, typename Device> struct matrixSetToAnother
179305
{
306+
/// @brief initialize matrix B with A
307+
///
308+
/// Input Parameters
309+
/// \param d : the type of computing device
310+
/// \param n : first dimension of matrix
311+
/// \param A : input matrix A
312+
/// \param LDA : leading dimension of A
313+
/// \param LDB : leading dimension of B
314+
///
315+
/// Output Parameters
316+
/// \param B : output matrix B
180317
void operator()(const Device* d,
181318
const int& n,
182319
const std::complex<FPTYPE>* A,

0 commit comments

Comments
 (0)