Skip to content

Commit ebc3aed

Browse files
author
dyzheng
committed
Refactor: format and less parameters for nonlocal fs kernels
1 parent 2afbee4 commit ebc3aed

File tree

11 files changed

+128
-199
lines changed

11 files changed

+128
-199
lines changed

source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -415,29 +415,27 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
415415
// calculate stress for target (ipol, jpol)
416416
if(npol == 1)
417417
{
418-
const int current_spin = this->kv_->isk[ik];
419-
cal_stress_nl_op()(this->ctx,
420-
nondiagonal,
421-
ipol,
422-
jpol,
423-
nkb,
424-
npm,
425-
this->ntype,
426-
current_spin, // uspp only
427-
this->nbands,
428-
ik,
429-
this->nlpp_->deeq.getBound2(),
430-
this->nlpp_->deeq.getBound3(),
431-
this->nlpp_->deeq.getBound4(),
432-
atom_nh,
433-
atom_na,
434-
d_wg,
435-
d_ekb,
436-
qq_nt,
437-
deeq,
438-
becp,
439-
dbecp,
440-
stress);
418+
const int current_spin = this->kv_->isk[ik];
419+
cal_stress_nl_op()(this->ctx,
420+
nondiagonal,
421+
ipol,
422+
jpol,
423+
nkb,
424+
npm,
425+
this->ntype,
426+
current_spin, // uspp only
427+
this->nlpp_->deeq.getBound2(),
428+
this->nlpp_->deeq.getBound3(),
429+
this->nlpp_->deeq.getBound4(),
430+
atom_nh,
431+
atom_na,
432+
d_wg + this->nbands * ik,
433+
d_ekb + this->nbands * ik,
434+
qq_nt,
435+
deeq,
436+
becp,
437+
dbecp,
438+
stress);
441439
}
442440
else
443441
{
@@ -447,15 +445,13 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
447445
nkb,
448446
npm,
449447
this->ntype,
450-
this->nbands,
451-
ik,
452448
this->nlpp_->deeq_nc.getBound2(),
453449
this->nlpp_->deeq_nc.getBound3(),
454450
this->nlpp_->deeq_nc.getBound4(),
455451
atom_nh,
456452
atom_na,
457-
d_wg,
458-
d_ekb,
453+
d_wg + this->nbands * ik,
454+
d_ekb + this->nbands * ik,
459455
qq_nt,
460456
this->nlpp_->template get_deeq_nc_data<FPTYPE>(),
461457
becp,
@@ -668,21 +664,19 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_force(int ik, int npm, FPTYPE* force
668664
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
669665
nondiagonal,
670666
npm,
671-
this->nbands,
672667
this->ntype,
673668
current_spin,
674669
this->nlpp_->deeq.getBound2(),
675670
this->nlpp_->deeq.getBound3(),
676671
this->nlpp_->deeq.getBound4(),
677672
force_nc,
678673
this->nbands,
679-
ik,
680674
nkb,
681675
atom_nh,
682676
atom_na,
683677
this->ucell_->tpiba,
684-
d_wg,
685-
d_ekb,
678+
d_wg + this->nbands * ik,
679+
d_ekb + this->nbands * ik,
686680
qq_nt,
687681
deeq,
688682
becp,
@@ -693,20 +687,18 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_force(int ik, int npm, FPTYPE* force
693687
{
694688
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
695689
npm,
696-
this->nbands,
697690
this->ntype,
698691
this->nlpp_->deeq_nc.getBound2(),
699692
this->nlpp_->deeq_nc.getBound3(),
700693
this->nlpp_->deeq_nc.getBound4(),
701694
force_nc,
702695
this->nbands,
703-
ik,
704696
nkb,
705697
atom_nh,
706698
atom_na,
707699
this->ucell_->tpiba,
708-
d_wg,
709-
d_ekb,
700+
d_wg + this->nbands * ik,
701+
d_ekb + this->nbands * ik,
710702
qq_nt,
711703
this->nlpp_->template get_deeq_nc_data<FPTYPE>(),
712704
becp,

source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/force_op.cu

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,13 @@ __global__ void cal_vkb1_nl(
3535
template <typename FPTYPE>
3636
__global__ void cal_force_nl(
3737
const bool nondiagonal,
38-
const int wg_nc,
3938
const int ntype,
4039
const int spin,
4140
const int deeq_2,
4241
const int deeq_3,
4342
const int deeq_4,
4443
const int forcenl_nc,
4544
const int nbands,
46-
const int ik,
4745
const int nkb,
4846
const int *atom_nh,
4947
const int *atom_na,
@@ -65,11 +63,11 @@ __global__ void cal_force_nl(
6563
sum += atom_na[ii] * atom_nh[ii];
6664
}
6765

68-
int Nprojs = atom_nh[it];
69-
FPTYPE fac = d_wg[ik * wg_nc + ib] * 2.0 * tpiba;
70-
FPTYPE ekb_now = d_ekb[ik * wg_nc + ib];
66+
int nproj = atom_nh[it];
67+
FPTYPE fac = d_wg[ib] * 2.0 * tpiba;
68+
FPTYPE ekb_now = d_ekb[ib];
7169
for (int ia = 0; ia < atom_na[it]; ia++) {
72-
for (int ip = threadIdx.x; ip < Nprojs; ip += blockDim.x) {
70+
for (int ip = threadIdx.x; ip < nproj; ip += blockDim.x) {
7371
// FPTYPE ps = GlobalC::ppcell.deeq[spin, iat, ip, ip];
7472
FPTYPE ps = deeq[((spin * deeq_2 + iat) * deeq_3 + ip) * deeq_4 + ip]
7573
- ekb_now * qq_nt[it * deeq_3 * deeq_4 + ip * deeq_4 + ip];
@@ -85,8 +83,8 @@ __global__ void cal_force_nl(
8583
}
8684

8785
if (nondiagonal) {
88-
//for (int ip2=0; ip2<Nprojs; ip2++)
89-
for (int ip2 = 0; ip2 < Nprojs; ip2++) {
86+
//for (int ip2=0; ip2<nproj; ip2++)
87+
for (int ip2 = 0; ip2 < nproj; ip2++) {
9088
if (ip != ip2) {
9189
const int jnkb = sum + ip2;
9290
ps = deeq[((spin * deeq_2 + iat) * deeq_3 + ip) * deeq_4 + ip2]
@@ -101,7 +99,7 @@ __global__ void cal_force_nl(
10199
}
102100
}
103101
iat += 1;
104-
sum += Nprojs;
102+
sum += nproj;
105103
}
106104
}
107105

@@ -134,15 +132,13 @@ template <typename FPTYPE>
134132
void cal_force_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* ctx,
135133
const bool& nondiagonal,
136134
const int& nbands_occ,
137-
const int& wg_nc,
138135
const int& ntype,
139136
const int& spin,
140137
const int& deeq_2,
141138
const int& deeq_3,
142139
const int& deeq_4,
143140
const int& forcenl_nc,
144141
const int& nbands,
145-
const int& ik,
146142
const int& nkb,
147143
const int* atom_nh,
148144
const int* atom_na,
@@ -157,9 +153,9 @@ void cal_force_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_dev
157153
{
158154
cal_force_nl<FPTYPE><<<nbands_occ * ntype, THREADS_PER_BLOCK>>>(
159155
nondiagonal,
160-
wg_nc, ntype, spin,
156+
ntype, spin,
161157
deeq_2, deeq_3, deeq_4,
162-
forcenl_nc, nbands, ik, nkb,
158+
forcenl_nc, nbands, nkb,
163159
atom_nh, atom_na,
164160
tpiba,
165161
d_wg, d_ekb, qq_nt, deeq,
@@ -172,14 +168,12 @@ void cal_force_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_dev
172168

173169
template <typename FPTYPE>
174170
__global__ void cal_force_nl(
175-
const int wg_nc,
176171
const int ntype,
177172
const int deeq_2,
178173
const int deeq_3,
179174
const int deeq_4,
180175
const int forcenl_nc,
181176
const int nbands,
182-
const int ik,
183177
const int nkb,
184178
const int *atom_nh,
185179
const int *atom_na,
@@ -202,13 +196,13 @@ __global__ void cal_force_nl(
202196
sum += atom_na[ii] * atom_nh[ii];
203197
}
204198

205-
int Nprojs = atom_nh[it];
206-
FPTYPE fac = d_wg[ik * wg_nc + ib] * 2.0 * tpiba;
207-
FPTYPE ekb_now = d_ekb[ik * wg_nc + ib];
199+
int nproj = atom_nh[it];
200+
FPTYPE fac = d_wg[ib] * 2.0 * tpiba;
201+
FPTYPE ekb_now = d_ekb[ib];
208202
for (int ia = 0; ia < atom_na[it]; ia++) {
209-
for (int ip = threadIdx.x; ip < Nprojs; ip += blockDim.x) {
203+
for (int ip = threadIdx.x; ip < nproj; ip += blockDim.x) {
210204
const int inkb = sum + ip;
211-
for (int ip2 = 0; ip2 < Nprojs; ip2++)
205+
for (int ip2 = 0; ip2 < nproj; ip2++)
212206
{
213207
// Effective values of the D-eS coefficients
214208
const thrust::complex<FPTYPE> ps_qq = - ekb_now * qq_nt[it * deeq_3 * deeq_4 + ip * deeq_4 + ip2];
@@ -231,22 +225,20 @@ __global__ void cal_force_nl(
231225
}
232226
}
233227
iat += 1;
234-
sum += Nprojs;
228+
sum += nproj;
235229
}
236230
}
237231

238232
// interface for nspin=4 only
239233
template <typename FPTYPE>
240234
void cal_force_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* ctx,
241235
const int& nbands_occ,
242-
const int& wg_nc,
243236
const int& ntype,
244237
const int& deeq_2,
245238
const int& deeq_3,
246239
const int& deeq_4,
247240
const int& forcenl_nc,
248241
const int& nbands,
249-
const int& ik,
250242
const int& nkb,
251243
const int* atom_nh,
252244
const int* atom_na,
@@ -260,9 +252,9 @@ void cal_force_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_dev
260252
FPTYPE* force)
261253
{
262254
cal_force_nl<FPTYPE><<<nbands_occ * ntype, THREADS_PER_BLOCK>>>(
263-
wg_nc, ntype,
255+
ntype,
264256
deeq_2, deeq_3, deeq_4,
265-
forcenl_nc, nbands, ik, nkb,
257+
forcenl_nc, nbands, nkb,
266258
atom_nh, atom_na,
267259
tpiba,
268260
d_wg, d_ekb, qq_nt,

source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/stress_op.cu

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ __global__ void cal_stress_nl(
107107
const int nkb,
108108
const int ntype,
109109
const int spin,
110-
const int wg_nc,
111-
const int ik,
112110
const int deeq_2,
113111
const int deeq_3,
114112
const int deeq_4,
@@ -125,22 +123,25 @@ __global__ void cal_stress_nl(
125123
int ib = blockIdx.x / ntype;
126124
int it = blockIdx.x % ntype;
127125

128-
int iat = 0, sum = 0;
126+
int iat = 0;
127+
int sum = 0;
129128
for (int ii = 0; ii < it; ii++) {
130129
iat += atom_na[ii];
131130
sum += atom_na[ii] * atom_nh[ii];
132131
}
133132

134-
FPTYPE stress_var = 0, fac = d_wg[ik * wg_nc + ib] * 1.0, ekb_now = d_ekb[ik * wg_nc + ib];
135-
const int Nprojs = atom_nh[it];
133+
FPTYPE stress_var = 0;
134+
const FPTYPE fac = d_wg[ib];
135+
const FPTYPE ekb_now = d_ekb[ib];
136+
const int nproj = atom_nh[it];
136137
for (int ia = 0; ia < atom_na[it]; ia++)
137138
{
138-
for (int ii = threadIdx.x; ii < Nprojs * Nprojs; ii += blockDim.x) {
139-
int ip1 = ii / Nprojs, ip2 = ii % Nprojs;
139+
for (int ii = threadIdx.x; ii < nproj * nproj; ii += blockDim.x) {
140+
const int ip1 = ii / nproj, ip2 = ii % nproj;
140141
if(!nondiagonal && ip1 != ip2) {
141142
continue;
142143
}
143-
FPTYPE ps = deeq[((spin * deeq_2 + iat) * deeq_3 + ip1) * deeq_4 + ip2]
144+
const FPTYPE ps = deeq[((spin * deeq_2 + iat) * deeq_3 + ip1) * deeq_4 + ip2]
144145
- ekb_now * qq_nt[it * deeq_3 * deeq_4 + ip1 * deeq_4 + ip2];
145146
const int inkb1 = sum + ip1;
146147
const int inkb2 = sum + ip2;
@@ -149,7 +150,7 @@ __global__ void cal_stress_nl(
149150
stress_var -= ps * fac * dbb;
150151
}
151152
++iat;
152-
sum+=Nprojs;
153+
sum+=nproj;
153154
}//ia
154155
__syncwarp();
155156
warp_reduce(stress_var);
@@ -204,8 +205,6 @@ void cal_stress_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_de
204205
const int& nbands_occ,
205206
const int& ntype,
206207
const int& spin,
207-
const int& wg_nc,
208-
const int& ik,
209208
const int& deeq_2,
210209
const int& deeq_3,
211210
const int& deeq_4,
@@ -226,8 +225,6 @@ void cal_stress_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_de
226225
nkb,
227226
ntype,
228227
spin,
229-
wg_nc,
230-
ik,
231228
deeq_2,
232229
deeq_3,
233230
deeq_4,
@@ -250,8 +247,6 @@ __global__ void cal_stress_nl(
250247
const int jpol,
251248
const int nkb,
252249
const int ntype,
253-
const int wg_nc,
254-
const int ik,
255250
const int deeq_2,
256251
const int deeq_3,
257252
const int deeq_4,
@@ -277,14 +272,14 @@ __global__ void cal_stress_nl(
277272
}
278273

279274
FPTYPE stress_var = 0;
280-
const FPTYPE fac = d_wg[ik * wg_nc + ib] * 1.0;
281-
const FPTYPE ekb_now = d_ekb[ik * wg_nc + ib];
282-
const int Nprojs = atom_nh[it];
275+
const FPTYPE fac = d_wg[ib];
276+
const FPTYPE ekb_now = d_ekb[ib];
277+
const int nproj = atom_nh[it];
283278
for (int ia = 0; ia < atom_na[it]; ia++)
284279
{
285-
for (int ii = threadIdx.x; ii < Nprojs * Nprojs; ii += blockDim.x) {
286-
const int ip1 = ii / Nprojs;
287-
const int ip2 = ii % Nprojs;
280+
for (int ii = threadIdx.x; ii < nproj * nproj; ii += blockDim.x) {
281+
const int ip1 = ii / nproj;
282+
const int ip2 = ii % nproj;
288283
const thrust::complex<FPTYPE> ps_qq = - ekb_now * qq_nt[it * deeq_3 * deeq_4 + ip1 * deeq_4 + ip2];
289284
const thrust::complex<FPTYPE> ps0 = deeq_nc[((iat + ia) * deeq_3 + ip1) * deeq_4 + ip2] + ps_qq;
290285
const thrust::complex<FPTYPE> ps1 = deeq_nc[((1 * deeq_2 + iat + ia) * deeq_3 + ip1) * deeq_4 + ip2];
@@ -300,7 +295,7 @@ __global__ void cal_stress_nl(
300295
stress_var -= fac * (ps0 * dbb0 + ps1 * dbb1 + ps2 * dbb2 + ps3 * dbb3).real();
301296
}
302297
++iat;
303-
sum+=Nprojs;
298+
sum+=nproj;
304299
}//ia
305300
__syncwarp();
306301
warp_reduce(stress_var);
@@ -316,8 +311,6 @@ void cal_stress_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_de
316311
const int& nkb,
317312
const int& nbands_occ,
318313
const int& ntype,
319-
const int& wg_nc,
320-
const int& ik,
321314
const int& deeq_2,
322315
const int& deeq_3,
323316
const int& deeq_4,
@@ -336,8 +329,6 @@ void cal_stress_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_de
336329
jpol,
337330
nkb,
338331
ntype,
339-
wg_nc,
340-
ik,
341332
deeq_2,
342333
deeq_3,
343334
deeq_4,

0 commit comments

Comments
 (0)