Skip to content

Commit e276480

Browse files
dyzhengdyzhengpre-commit-ci-lite[bot]
authored
Fix: force and stress calculation with noncollinear-spin or SOC for PW code (#5377)
* Fix: noncollinear-spin force and stress with PW * Fix: bug of pw force&stress calculation * fix: soc error and cuda kernels * Fix: CUDA and ROCM * [pre-commit.ci lite] apply automatic fixes * Fix: nspin=4 error on DCU/GPU * Refactor: format and less parameters for nonlocal fs kernels * Test: add force and stress test for noncollinear-spin and soc --------- Co-authored-by: dyzheng <[email protected]> Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
1 parent 6cdb848 commit e276480

File tree

26 files changed

+886
-227
lines changed

26 files changed

+886
-227
lines changed

source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ void XC_Functional::gradcorr(double &etxc, double &vtxc, ModuleBase::matrix &v,
184184
}
185185

186186
gdr2 = new ModuleBase::Vector3<double>[rhopw->nrxx];
187-
h2 = new ModuleBase::Vector3<double>[rhopw->nrxx];
187+
if(!is_stress) h2 = new ModuleBase::Vector3<double>[rhopw->nrxx];
188188

189189
XC_Functional::grad_rho( rhogsum1 , gdr1, rhopw, ucell->tpiba);
190190
XC_Functional::grad_rho( rhogsum2 , gdr2, rhopw, ucell->tpiba);

source/module_hamilt_pw/hamilt_pwdft/forces_nl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void Forces<FPTYPE, Device>::cal_force_nl(ModuleBase::matrix& forcenl,
4545
break;
4646
}
4747
}
48-
const int npm = ucell_in.get_npol() * nbands_occ;
48+
const int npm = nbands_occ;
4949
// calculate becp = <psi|beta> for all beta functions
5050
nl_tools.cal_becp(ik, npm);
5151
for (int ipol = 0; ipol < 3; ipol++)

source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp

Lines changed: 114 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,6 @@ void FS_Nonlocal_tools<FPTYPE, Device>::allocate_memory(const ModuleBase::matrix
108108
this->atom_na = h_atom_na.data();
109109
this->ppcell_vkb = this->nlpp_->vkb.c;
110110
}
111-
112-
// prepare the memory of the becp and dbecp:
113-
// becp: <Beta(nkb,npw)|psi(nbnd,npw)>
114-
// dbecp: <dBeta(nkb,npw)/dG|psi(nbnd,npw)>
115-
resmem_complex_op()(this->ctx, becp, this->nbands * nkb, "Stress::becp");
116-
resmem_complex_op()(this->ctx, dbecp, 6 * this->nbands * nkb, "Stress::dbecp");
117111
}
118112

119113
template <typename FPTYPE, typename Device>
@@ -163,9 +157,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
163157
{
164158
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_becp");
165159
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_becp");
160+
const int npol = this->ucell_->get_npol();
161+
const int size_becp = this->nbands * npol * this->nkb;
162+
const int size_becp_act = npm * npol * this->nkb;
166163
if (this->becp == nullptr)
167164
{
168-
resmem_complex_op()(this->ctx, becp, this->nbands * this->nkb);
165+
resmem_complex_op()(this->ctx, becp, size_becp);
169166
}
170167

171168
// prepare math tools
@@ -249,11 +246,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
249246
}
250247
const char transa = 'C';
251248
const char transb = 'N';
249+
int npm_npol = npm * npol;
252250
gemm_op()(this->ctx,
253251
transa,
254252
transb,
255253
nkb,
256-
npm,
254+
npm_npol,
257255
npw,
258256
&ModuleBase::ONE,
259257
ppcell_vkb,
@@ -268,15 +266,15 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
268266
if (this->device == base_device::GpuDevice)
269267
{
270268
std::complex<FPTYPE>* h_becp = nullptr;
271-
resmem_complex_h_op()(this->cpu_ctx, h_becp, this->nbands * nkb);
272-
syncmem_complex_d2h_op()(this->cpu_ctx, this->ctx, h_becp, becp, this->nbands * nkb);
273-
Parallel_Reduce::reduce_pool(h_becp, this->nbands * nkb);
274-
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, becp, h_becp, this->nbands * nkb);
269+
resmem_complex_h_op()(this->cpu_ctx, h_becp, size_becp_act);
270+
syncmem_complex_d2h_op()(this->cpu_ctx, this->ctx, h_becp, becp, size_becp_act);
271+
Parallel_Reduce::reduce_pool(h_becp, size_becp_act);
272+
syncmem_complex_h2d_op()(this->ctx, this->cpu_ctx, becp, h_becp, size_becp_act);
275273
delmem_complex_h_op()(this->cpu_ctx, h_becp);
276274
}
277275
else
278276
{
279-
Parallel_Reduce::reduce_pool(becp, this->nbands * this->nkb);
277+
Parallel_Reduce::reduce_pool(becp, size_becp_act);
280278
}
281279
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_becp");
282280
}
@@ -287,9 +285,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
287285
{
288286
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_dbecp_s");
289287
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_s");
288+
const int npol = this->ucell_->get_npol();
289+
const int size_becp = this->nbands * npol * this->nkb;
290+
const int npm_npol = npm * npol;
290291
if (this->dbecp == nullptr)
291292
{
292-
resmem_complex_op()(this->ctx, dbecp, this->nbands * this->nkb);
293+
resmem_complex_op()(this->ctx, dbecp, size_becp);
293294
}
294295

295296
// prepare math tools
@@ -401,7 +402,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
401402
transa,
402403
transb,
403404
nkb,
404-
npm,
405+
npm_npol,
405406
npw,
406407
&ModuleBase::ONE,
407408
ppcell_vkb,
@@ -412,43 +413,68 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
412413
dbecp,
413414
nkb);
414415
// calculate stress for target (ipol, jpol)
415-
const int current_spin = this->kv_->isk[ik];
416-
cal_stress_nl_op()(this->ctx,
417-
nondiagonal,
418-
ipol,
419-
jpol,
420-
nkb,
421-
npm,
422-
this->ntype,
423-
current_spin, // uspp only
424-
this->nbands,
425-
ik,
426-
this->nlpp_->deeq.getBound2(),
427-
this->nlpp_->deeq.getBound3(),
428-
this->nlpp_->deeq.getBound4(),
429-
atom_nh,
430-
atom_na,
431-
d_wg,
432-
d_ekb,
433-
qq_nt,
434-
deeq,
435-
becp,
436-
dbecp,
437-
stress);
416+
if(npol == 1)
417+
{
418+
const int current_spin = this->kv_->isk[ik];
419+
cal_stress_nl_op()(this->ctx,
420+
nondiagonal,
421+
ipol,
422+
jpol,
423+
nkb,
424+
npm,
425+
this->ntype,
426+
current_spin, // uspp only
427+
this->nlpp_->deeq.getBound2(),
428+
this->nlpp_->deeq.getBound3(),
429+
this->nlpp_->deeq.getBound4(),
430+
atom_nh,
431+
atom_na,
432+
d_wg + this->nbands * ik,
433+
d_ekb + this->nbands * ik,
434+
qq_nt,
435+
deeq,
436+
becp,
437+
dbecp,
438+
stress);
439+
}
440+
else
441+
{
442+
cal_stress_nl_op()(this->ctx,
443+
ipol,
444+
jpol,
445+
nkb,
446+
npm,
447+
this->ntype,
448+
this->nlpp_->deeq_nc.getBound2(),
449+
this->nlpp_->deeq_nc.getBound3(),
450+
this->nlpp_->deeq_nc.getBound4(),
451+
atom_nh,
452+
atom_na,
453+
d_wg + this->nbands * ik,
454+
d_ekb + this->nbands * ik,
455+
qq_nt,
456+
this->nlpp_->template get_deeq_nc_data<FPTYPE>(),
457+
becp,
458+
dbecp,
459+
stress);
460+
}
438461
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_s");
439462
}
440463

441464
template <typename FPTYPE, typename Device>
442465
void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
443466
{
444-
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_dbecp_s");
467+
ModuleBase::TITLE("FS_Nonlocal_tools", "cal_dbecp_f");
445468
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_f");
469+
const int npol = this->ucell_->get_npol();
470+
const int npm_npol = npm * npol;
471+
const int size_becp = this->nbands * npol * this->nkb;
446472
if (this->dbecp == nullptr)
447473
{
448-
resmem_complex_op()(this->ctx, dbecp, 3 * this->nbands * this->nkb);
474+
resmem_complex_op()(this->ctx, dbecp, 3 * size_becp);
449475
}
450476

451-
std::complex<FPTYPE>* dbecp_ptr = this->dbecp + ipol * this->nbands * this->nkb;
477+
std::complex<FPTYPE>* dbecp_ptr = this->dbecp + ipol * size_becp;
452478
const std::complex<FPTYPE>* vkb_ptr = this->ppcell_vkb;
453479
std::complex<FPTYPE>* vkb_deri_ptr = this->ppcell_vkb;
454480

@@ -481,7 +507,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
481507
transa,
482508
transb,
483509
this->nkb,
484-
npm,
510+
npm_npol,
485511
npw,
486512
&ModuleBase::ONE,
487513
vkb_deri_ptr,
@@ -491,7 +517,6 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
491517
&ModuleBase::ZERO,
492518
dbecp_ptr,
493519
nkb);
494-
495520
this->revert_vkb(npw, ipol);
496521
this->pre_ik_f = ik;
497522
ModuleBase::timer::tick("FS_Nonlocal_tools", "cal_dbecp_f");
@@ -634,29 +659,52 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_force(int ik, int npm, FPTYPE* force
634659
const int current_spin = this->kv_->isk[ik];
635660
const int force_nc = 3;
636661
// calculate the force
637-
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
638-
nondiagonal,
639-
npm,
640-
this->nbands,
641-
this->ntype,
642-
current_spin,
643-
this->nlpp_->deeq.getBound2(),
644-
this->nlpp_->deeq.getBound3(),
645-
this->nlpp_->deeq.getBound4(),
646-
force_nc,
647-
this->nbands,
648-
ik,
649-
nkb,
650-
atom_nh,
651-
atom_na,
652-
this->ucell_->tpiba,
653-
d_wg,
654-
d_ekb,
655-
qq_nt,
656-
deeq,
657-
becp,
658-
dbecp,
659-
force);
662+
if(this->ucell_->get_npol() == 1)
663+
{
664+
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
665+
nondiagonal,
666+
npm,
667+
this->ntype,
668+
current_spin,
669+
this->nlpp_->deeq.getBound2(),
670+
this->nlpp_->deeq.getBound3(),
671+
this->nlpp_->deeq.getBound4(),
672+
force_nc,
673+
this->nbands,
674+
nkb,
675+
atom_nh,
676+
atom_na,
677+
this->ucell_->tpiba,
678+
d_wg + this->nbands * ik,
679+
d_ekb + this->nbands * ik,
680+
qq_nt,
681+
deeq,
682+
becp,
683+
dbecp,
684+
force);
685+
}
686+
else
687+
{
688+
cal_force_nl_op<FPTYPE, Device>()(this->ctx,
689+
npm,
690+
this->ntype,
691+
this->nlpp_->deeq_nc.getBound2(),
692+
this->nlpp_->deeq_nc.getBound3(),
693+
this->nlpp_->deeq_nc.getBound4(),
694+
force_nc,
695+
this->nbands,
696+
nkb,
697+
atom_nh,
698+
atom_na,
699+
this->ucell_->tpiba,
700+
d_wg + this->nbands * ik,
701+
d_ekb + this->nbands * ik,
702+
qq_nt,
703+
this->nlpp_->template get_deeq_nc_data<FPTYPE>(),
704+
becp,
705+
dbecp,
706+
force);
707+
}
660708
}
661709

662710
// template instantiation

0 commit comments

Comments
 (0)