Skip to content

Commit 2afbee4

Browse files
author
dyzheng
committed
Merge branch 'pw_soc_force' of github.com:dyzheng/abacus-develop into pw_fs_soc
2 parents 4294ee7 + 05ccd3b commit 2afbee4

File tree

16 files changed

+813
-91
lines changed

16 files changed

+813
-91
lines changed

source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ void XC_Functional::gradcorr(double &etxc, double &vtxc, ModuleBase::matrix &v,
184184
}
185185

186186
gdr2 = new ModuleBase::Vector3<double>[rhopw->nrxx];
187-
h2 = new ModuleBase::Vector3<double>[rhopw->nrxx];
187+
if(!is_stress) h2 = new ModuleBase::Vector3<double>[rhopw->nrxx];
188188

189189
XC_Functional::grad_rho( rhogsum1 , gdr1, rhopw, ucell->tpiba);
190190
XC_Functional::grad_rho( rhogsum2 , gdr2, rhopw, ucell->tpiba);

source/module_hamilt_pw/hamilt_pwdft/forces_nl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void Forces<FPTYPE, Device>::cal_force_nl(ModuleBase::matrix& forcenl,
4545
break;
4646
}
4747
}
48-
const int npm = ucell_in.get_npol() * nbands_occ;
48+
const int npm = nbands_occ;
4949
// calculate becp = <psi|beta> for all beta functions
5050
nl_tools.cal_becp(ik, npm);
5151
for (int ipol = 0; ipol < 3; ipol++)

source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp

Lines changed: 99 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,6 @@ void FS_Nonlocal_tools<FPTYPE, Device>::allocate_memory(const ModuleBase::matrix
108108
this->atom_na = h_atom_na.data();
109109
this->ppcell_vkb = this->nlpp_->vkb.c;
110110
}
111-
112-
// prepare the memory of the becp and dbecp:
113-
// becp: <Beta(nkb,npw)|psi(nbnd,npw)>
114-
// dbecp: <dBeta(nkb,npw)/dG|psi(nbnd,npw)>
115-
resmem_complex_op()(this->ctx, becp, this->nbands * nkb, "Stress::becp");
116-
resmem_complex_op()(this->ctx, dbecp, 6 * this->nbands * nkb, "Stress::dbecp");
117111
}
118112

119113
template <typename FPTYPE, typename Device>
@@ -163,9 +157,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
163157
{
164158
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_becp");
165159
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_becp");
160+
const int npol = this->ucell_->get_npol();
161+
const int size_becp = this->nbands * npol * this->nkb;
162+
const int size_becp_act = npm * npol * this->nkb;
166163
if (this->becp == nullptr)
167164
{
168-
resmem_complex_op()(this->ctx, becp, this->nbands * this->nkb);
165+
resmem_complex_op()(this->ctx, becp, size_becp);
169166
}
170167

171168
// prepare math tools
@@ -249,11 +246,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
249246
}
250247
const char transa = 'C';
251248
const char transb = 'N';
249+
int npm_npol = npm * npol;
252250
gemm_op()(this->ctx,
253251
transa,
254252
transb,
255253
nkb,
256-
npm,
254+
npm_npol,
257255
npw,
258256
&ModuleBase::ONE,
259257
ppcell_vkb,
@@ -268,15 +266,15 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
268266
if (this->device == base_device::GpuDevice)
269267
{
270268
std::complex<FPTYPE>* h_becp = nullptr;
271-
resmem_complex_h_op()(this->cpu_ctx, h_becp, this->nbands * nkb);
272-
syncmem_complex_d2h_op()(this->cpu_ctx, this->ctx, h_becp, becp, this->nbands * nkb);
273-
Parallel_Reduce::reduce_pool(h_becp, this->nbands * nkb);
274-
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, becp, h_becp, this->nbands * nkb);
269+
resmem_complex_h_op()(this->cpu_ctx, h_becp, size_becp_act);
270+
syncmem_complex_d2h_op()(this->cpu_ctx, this->ctx, h_becp, becp, size_becp_act);
271+
Parallel_Reduce::reduce_pool(h_becp, size_becp_act);
272+
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, becp, h_becp, size_becp_act);
275273
delmem_complex_h_op()(this->cpu_ctx, h_becp);
276274
}
277275
else
278276
{
279-
Parallel_Reduce::reduce_pool(becp, this->nbands * this->nkb);
277+
Parallel_Reduce::reduce_pool(becp, size_becp_act);
280278
}
281279
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_becp");
282280
}
@@ -287,9 +285,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
287285
{
288286
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_dbecp_s");
289287
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_s");
288+
const int npol = this->ucell_->get_npol();
289+
const int size_becp = this->nbands * npol * this->nkb;
290+
const int npm_npol = npm * npol;
290291
if (this->dbecp == nullptr)
291292
{
292-
resmem_complex_op()(this->ctx, dbecp, this->nbands * this->nkb);
293+
resmem_complex_op()(this->ctx, dbecp, size_becp);
293294
}
294295

295296
// prepare math tools
@@ -401,7 +402,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
401402
transa,
402403
transb,
403404
nkb,
404-
npm,
405+
npm_npol,
405406
npw,
406407
&ModuleBase::ONE,
407408
ppcell_vkb,
@@ -412,6 +413,8 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
412413
dbecp,
413414
nkb);
414415
// calculate stress for target (ipol, jpol)
416+
if(npol == 1)
417+
{
415418
const int current_spin = this->kv_->isk[ik];
416419
cal_stress_nl_op()(this->ctx,
417420
nondiagonal,
@@ -435,20 +438,47 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
435438
becp,
436439
dbecp,
437440
stress);
441+
}
442+
else
443+
{
444+
cal_stress_nl_op()(this->ctx,
445+
ipol,
446+
jpol,
447+
nkb,
448+
npm,
449+
this->ntype,
450+
this->nbands,
451+
ik,
452+
this->nlpp_->deeq_nc.getBound2(),
453+
this->nlpp_->deeq_nc.getBound3(),
454+
this->nlpp_->deeq_nc.getBound4(),
455+
atom_nh,
456+
atom_na,
457+
d_wg,
458+
d_ekb,
459+
qq_nt,
460+
this->nlpp_->template get_deeq_nc_data<FPTYPE>(),
461+
becp,
462+
dbecp,
463+
stress);
464+
}
438465
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_s");
439466
}
440467

441468
template <typename FPTYPE, typename Device>
442469
void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
443470
{
444-
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_dbecp_s");
471+
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_dbecp_f");
445472
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_f");
473+
const int npol = this->ucell_->get_npol();
474+
const int npm_npol = npm * npol;
475+
const int size_becp = this->nbands * npol * this->nkb;
446476
if (this->dbecp == nullptr)
447477
{
448-
resmem_complex_op()(this->ctx, dbecp, 3 * this->nbands * this->nkb);
478+
resmem_complex_op()(this->ctx, dbecp, 3 * size_becp);
449479
}
450480

451-
std::complex<FPTYPE>* dbecp_ptr = this->dbecp + ipol * this->nbands * this->nkb;
481+
std::complex<FPTYPE>* dbecp_ptr = this->dbecp + ipol * size_becp;
452482
const std::complex<FPTYPE>* vkb_ptr = this->ppcell_vkb;
453483
std::complex<FPTYPE>* vkb_deri_ptr = this->ppcell_vkb;
454484

@@ -481,7 +511,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
481511
transa,
482512
transb,
483513
this->nkb,
484-
npm,
514+
npm_npol,
485515
npw,
486516
&ModuleBase::ONE,
487517
vkb_deri_ptr,
@@ -491,7 +521,6 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
491521
&ModuleBase::ZERO,
492522
dbecp_ptr,
493523
nkb);
494-
495524
this->revert_vkb(npw, ipol);
496525
this->pre_ik_f = ik;
497526
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_f");
@@ -634,29 +663,56 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_force(int ik, int npm, FPTYPE* force
634663
const int current_spin = this->kv_->isk[ik];
635664
const int force_nc = 3;
636665
// calculate the force
637-
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
638-
nondiagonal,
639-
npm,
640-
this->nbands,
641-
this->ntype,
642-
current_spin,
643-
this->nlpp_->deeq.getBound2(),
644-
this->nlpp_->deeq.getBound3(),
645-
this->nlpp_->deeq.getBound4(),
646-
force_nc,
647-
this->nbands,
648-
ik,
649-
nkb,
650-
atom_nh,
651-
atom_na,
652-
this->ucell_->tpiba,
653-
d_wg,
654-
d_ekb,
655-
qq_nt,
656-
deeq,
657-
becp,
658-
dbecp,
659-
force);
666+
if(this->ucell_->get_npol() == 1)
667+
{
668+
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
669+
nondiagonal,
670+
npm,
671+
this->nbands,
672+
this->ntype,
673+
current_spin,
674+
this->nlpp_->deeq.getBound2(),
675+
this->nlpp_->deeq.getBound3(),
676+
this->nlpp_->deeq.getBound4(),
677+
force_nc,
678+
this->nbands,
679+
ik,
680+
nkb,
681+
atom_nh,
682+
atom_na,
683+
this->ucell_->tpiba,
684+
d_wg,
685+
d_ekb,
686+
qq_nt,
687+
deeq,
688+
becp,
689+
dbecp,
690+
force);
691+
}
692+
else
693+
{
694+
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
695+
npm,
696+
this->nbands,
697+
this->ntype,
698+
this->nlpp_->deeq_nc.getBound2(),
699+
this->nlpp_->deeq_nc.getBound3(),
700+
this->nlpp_->deeq_nc.getBound4(),
701+
force_nc,
702+
this->nbands,
703+
ik,
704+
nkb,
705+
atom_nh,
706+
atom_na,
707+
this->ucell_->tpiba,
708+
d_wg,
709+
d_ekb,
710+
qq_nt,
711+
this->nlpp_->template get_deeq_nc_data<FPTYPE>(),
712+
becp,
713+
dbecp,
714+
force);
715+
}
660716
}
661717

662718
// template instantiation

source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/force_op.cu

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,110 @@ void cal_force_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_dev
170170
cudaCheckOnDebug();
171171
}
172172

173+
template <typename FPTYPE>
174+
__global__ void cal_force_nl(
175+
const int wg_nc,
176+
const int ntype,
177+
const int deeq_2,
178+
const int deeq_3,
179+
const int deeq_4,
180+
const int forcenl_nc,
181+
const int nbands,
182+
const int ik,
183+
const int nkb,
184+
const int *atom_nh,
185+
const int *atom_na,
186+
const FPTYPE tpiba,
187+
const FPTYPE *d_wg,
188+
const FPTYPE* d_ekb,
189+
const FPTYPE* qq_nt,
190+
const thrust::complex<FPTYPE> *deeq_nc,
191+
const thrust::complex<FPTYPE> *becp,
192+
const thrust::complex<FPTYPE> *dbecp,
193+
FPTYPE *force)
194+
{
195+
const int ib = blockIdx.x / ntype;
196+
const int ib2 = ib * 2;
197+
const int it = blockIdx.x % ntype;
198+
199+
int iat = 0, sum = 0;
200+
for (int ii = 0; ii < it; ii++) {
201+
iat += atom_na[ii];
202+
sum += atom_na[ii] * atom_nh[ii];
203+
}
204+
205+
int Nprojs = atom_nh[it];
206+
FPTYPE fac = d_wg[ik * wg_nc + ib] * 2.0 * tpiba;
207+
FPTYPE ekb_now = d_ekb[ik * wg_nc + ib];
208+
for (int ia = 0; ia < atom_na[it]; ia++) {
209+
for (int ip = threadIdx.x; ip < Nprojs; ip += blockDim.x) {
210+
const int inkb = sum + ip;
211+
for (int ip2 = 0; ip2 < Nprojs; ip2++)
212+
{
213+
// Effective values of the D-eS coefficients
214+
const thrust::complex<FPTYPE> ps_qq = - ekb_now * qq_nt[it * deeq_3 * deeq_4 + ip * deeq_4 + ip2];
215+
const int jnkb = sum + ip2;
216+
const thrust::complex<FPTYPE> ps0 = deeq_nc[((0 * deeq_2 + iat) * deeq_3 + ip) * deeq_4 + ip2] + ps_qq;
217+
const thrust::complex<FPTYPE> ps1 = deeq_nc[((1 * deeq_2 + iat) * deeq_3 + ip) * deeq_4 + ip2];
218+
const thrust::complex<FPTYPE> ps2 = deeq_nc[((2 * deeq_2 + iat) * deeq_3 + ip) * deeq_4 + ip2];
219+
const thrust::complex<FPTYPE> ps3 = deeq_nc[((3 * deeq_2 + iat) * deeq_3 + ip) * deeq_4 + ip2] + ps_qq;
220+
221+
for (int ipol = 0; ipol < 3; ipol++) {
222+
const int index0 = ipol * nbands * 2 * nkb + ib2 * nkb + inkb;
223+
const int index1 = ib2 * nkb + jnkb;
224+
const thrust::complex<FPTYPE> dbb0 = conj(dbecp[index0]) * becp[index1];
225+
const thrust::complex<FPTYPE> dbb1 = conj(dbecp[index0]) * becp[index1 + nkb];
226+
const thrust::complex<FPTYPE> dbb2 = conj(dbecp[index0 + nkb]) * becp[index1];
227+
const thrust::complex<FPTYPE> dbb3 = conj(dbecp[index0 + nkb]) * becp[index1 + nkb];
228+
const FPTYPE tmp = - fac * (ps0 * dbb0 + ps1 * dbb1 + ps2 * dbb2 + ps3 * dbb3).real();
229+
atomicAdd(force + iat * forcenl_nc + ipol, tmp);
230+
}
231+
}
232+
}
233+
iat += 1;
234+
sum += Nprojs;
235+
}
236+
}
237+
238+
// interface for nspin=4 only
239+
template <typename FPTYPE>
240+
void cal_force_nl_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* ctx,
241+
const int& nbands_occ,
242+
const int& wg_nc,
243+
const int& ntype,
244+
const int& deeq_2,
245+
const int& deeq_3,
246+
const int& deeq_4,
247+
const int& forcenl_nc,
248+
const int& nbands,
249+
const int& ik,
250+
const int& nkb,
251+
const int* atom_nh,
252+
const int* atom_na,
253+
const FPTYPE& tpiba,
254+
const FPTYPE* d_wg,
255+
const FPTYPE* d_ekb,
256+
const FPTYPE* qq_nt,
257+
const std::complex<FPTYPE>* deeq_nc,
258+
const std::complex<FPTYPE>* becp,
259+
const std::complex<FPTYPE>* dbecp,
260+
FPTYPE* force)
261+
{
262+
cal_force_nl<FPTYPE><<<nbands_occ * ntype, THREADS_PER_BLOCK>>>(
263+
wg_nc, ntype,
264+
deeq_2, deeq_3, deeq_4,
265+
forcenl_nc, nbands, ik, nkb,
266+
atom_nh, atom_na,
267+
tpiba,
268+
d_wg, d_ekb, qq_nt,
269+
reinterpret_cast<const thrust::complex<FPTYPE>*>(deeq_nc),
270+
reinterpret_cast<const thrust::complex<FPTYPE>*>(becp),
271+
reinterpret_cast<const thrust::complex<FPTYPE>*>(dbecp),
272+
force);// array of data
273+
274+
cudaCheckOnDebug();
275+
}
276+
173277
template <typename FPTYPE>
174278
__global__ void saveVkbValues_(
175279
const int *gcar_zero_ptrs,

source/module_hamilt_pw/hamilt_pwdft/kernels/cuda/nonlocal_op.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ __global__ void nonlocal_pw(
5454
thrust::complex<FPTYPE>* ps,
5555
const thrust::complex<FPTYPE>* becp)
5656
{
57-
const int ii = blockIdx.x / l2;
58-
const int jj = blockIdx.x % l2;
57+
const int ii = blockIdx.x * 2 / l2;
58+
const int jj = blockIdx.x * 2 % l2;
5959
for (int kk = threadIdx.x; kk < l3; kk += blockDim.x) {
6060
thrust::complex<FPTYPE> res1(0.0, 0.0);
6161
thrust::complex<FPTYPE> res2(0.0, 0.0);
@@ -121,7 +121,7 @@ void hamilt::nonlocal_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const b
121121
{
122122
// denghui implement 20221109
123123
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
124-
nonlocal_pw<FPTYPE><<<l1 * l2, THREADS_PER_BLOCK>>>(
124+
nonlocal_pw<FPTYPE><<<l1 * l2 / 2, THREADS_PER_BLOCK>>>(
125125
l1, l2, l3, // loop size
126126
sum, iat, nkb, // control params
127127
deeq_x, deeq_y, deeq_z,
@@ -138,4 +138,4 @@ void hamilt::nonlocal_pw_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const b
138138
template struct nonlocal_pw_op<float, base_device::DEVICE_GPU>;
139139
template struct nonlocal_pw_op<double, base_device::DEVICE_GPU>;
140140

141-
} // namespace hamilt
141+
} // namespace hamilt

0 commit comments

Comments
 (0)