@@ -108,12 +108,6 @@ void FS_Nonlocal_tools<FPTYPE, Device>::allocate_memory(const ModuleBase::matrix
108108 this ->atom_na = h_atom_na.data ();
109109 this ->ppcell_vkb = this ->nlpp_ ->vkb .c ;
110110 }
111-
112- // prepare the memory of the becp and dbecp:
113- // becp: <Beta(nkb,npw)|psi(nbnd,npw)>
114- // dbecp: <dBeta(nkb,npw)/dG|psi(nbnd,npw)>
115- resmem_complex_op ()(this ->ctx , becp, this ->nbands * nkb, " Stress::becp" );
116- resmem_complex_op ()(this ->ctx , dbecp, 6 * this ->nbands * nkb, " Stress::dbecp" );
117111}
118112
119113template <typename FPTYPE, typename Device>
@@ -163,9 +157,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
163157{
164158 ModuleBase::TITLE (" FS_Nonlocal_tools" , " cal_becp" );
165159 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_becp" );
160+ const int npol = this ->ucell_ ->get_npol ();
161+ const int size_becp = this ->nbands * npol * this ->nkb ;
162+ const int size_becp_act = npm * npol * this ->nkb ;
166163 if (this ->becp == nullptr )
167164 {
168- resmem_complex_op ()(this ->ctx , becp, this -> nbands * this -> nkb );
165+ resmem_complex_op ()(this ->ctx , becp, size_becp );
169166 }
170167
171168 // prepare math tools
@@ -249,11 +246,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
249246 }
250247 const char transa = ' C' ;
251248 const char transb = ' N' ;
249+ int npm_npol = npm * npol;
252250 gemm_op ()(this ->ctx ,
253251 transa,
254252 transb,
255253 nkb,
256- npm ,
254+ npm_npol ,
257255 npw,
258256 &ModuleBase::ONE,
259257 ppcell_vkb,
@@ -268,15 +266,15 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
268266 if (this ->device == base_device::GpuDevice)
269267 {
270268 std::complex <FPTYPE>* h_becp = nullptr ;
271- resmem_complex_h_op ()(this ->cpu_ctx , h_becp, this -> nbands * nkb );
272- syncmem_complex_d2h_op ()(this ->cpu_ctx , this ->ctx , h_becp, becp, this -> nbands * nkb );
273- Parallel_Reduce::reduce_pool (h_becp, this -> nbands * nkb );
274- syncmem_complex_h2d_op ()(this ->ctx , this ->cpu_ctx , becp, h_becp, this -> nbands * nkb );
269+ resmem_complex_h_op ()(this ->cpu_ctx , h_becp, size_becp_act );
270+ syncmem_complex_d2h_op ()(this ->cpu_ctx , this ->ctx , h_becp, becp, size_becp_act );
271+ Parallel_Reduce::reduce_pool (h_becp, size_becp_act );
272+ syncmem_complex_h2d_op ()(this ->ctx , this ->cpu_ctx , becp, h_becp, size_becp_act );
275273 delmem_complex_h_op ()(this ->cpu_ctx , h_becp);
276274 }
277275 else
278276 {
279- Parallel_Reduce::reduce_pool (becp, this -> nbands * this -> nkb );
277+ Parallel_Reduce::reduce_pool (becp, size_becp_act );
280278 }
281279 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_becp" );
282280}
@@ -287,9 +285,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
287285{
288286 ModuleBase::TITLE (" FS_Nonlocal_tools" , " cal_dbecp_s" );
289287 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_dbecp_s" );
288+ const int npol = this ->ucell_ ->get_npol ();
289+ const int size_becp = this ->nbands * npol * this ->nkb ;
290+ const int npm_npol = npm * npol;
290291 if (this ->dbecp == nullptr )
291292 {
292- resmem_complex_op ()(this ->ctx , dbecp, this -> nbands * this -> nkb );
293+ resmem_complex_op ()(this ->ctx , dbecp, size_becp );
293294 }
294295
295296 // prepare math tools
@@ -401,7 +402,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
401402 transa,
402403 transb,
403404 nkb,
404- npm ,
405+ npm_npol ,
405406 npw,
406407 &ModuleBase::ONE,
407408 ppcell_vkb,
@@ -412,6 +413,8 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
412413 dbecp,
413414 nkb);
414415 // calculate stress for target (ipol, jpol)
416+ if (npol == 1 )
417+ {
415418 const int current_spin = this ->kv_ ->isk [ik];
416419 cal_stress_nl_op ()(this ->ctx ,
417420 nondiagonal,
@@ -435,20 +438,47 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
435438 becp,
436439 dbecp,
437440 stress);
441+ }
442+ else
443+ {
444+ cal_stress_nl_op ()(this ->ctx ,
445+ ipol,
446+ jpol,
447+ nkb,
448+ npm,
449+ this ->ntype ,
450+ this ->nbands ,
451+ ik,
452+ this ->nlpp_ ->deeq_nc .getBound2 (),
453+ this ->nlpp_ ->deeq_nc .getBound3 (),
454+ this ->nlpp_ ->deeq_nc .getBound4 (),
455+ atom_nh,
456+ atom_na,
457+ d_wg,
458+ d_ekb,
459+ qq_nt,
460+ this ->nlpp_ ->template get_deeq_nc_data <FPTYPE>(),
461+ becp,
462+ dbecp,
463+ stress);
464+ }
438465 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_dbecp_s" );
439466}
440467
441468template <typename FPTYPE, typename Device>
442469void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
443470{
444- ModuleBase::TITLE (" FS_Nonlocal_tools" , " cal_dbecp_s " );
471+ ModuleBase::TITLE (" FS_Nonlocal_tools" , " cal_dbecp_f " );
445472 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_dbecp_f" );
473+ const int npol = this ->ucell_ ->get_npol ();
474+ const int npm_npol = npm * npol;
475+ const int size_becp = this ->nbands * npol * this ->nkb ;
446476 if (this ->dbecp == nullptr )
447477 {
448- resmem_complex_op ()(this ->ctx , dbecp, 3 * this -> nbands * this -> nkb );
478+ resmem_complex_op ()(this ->ctx , dbecp, 3 * size_becp );
449479 }
450480
451- std::complex <FPTYPE>* dbecp_ptr = this ->dbecp + ipol * this -> nbands * this -> nkb ;
481+ std::complex <FPTYPE>* dbecp_ptr = this ->dbecp + ipol * size_becp ;
452482 const std::complex <FPTYPE>* vkb_ptr = this ->ppcell_vkb ;
453483 std::complex <FPTYPE>* vkb_deri_ptr = this ->ppcell_vkb ;
454484
@@ -481,7 +511,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
481511 transa,
482512 transb,
483513 this ->nkb ,
484- npm ,
514+ npm_npol ,
485515 npw,
486516 &ModuleBase::ONE,
487517 vkb_deri_ptr,
@@ -491,7 +521,6 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
491521 &ModuleBase::ZERO,
492522 dbecp_ptr,
493523 nkb);
494-
495524 this ->revert_vkb (npw, ipol);
496525 this ->pre_ik_f = ik;
497526 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_dbecp_f" );
@@ -634,29 +663,56 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_force(int ik, int npm, FPTYPE* force
634663 const int current_spin = this ->kv_ ->isk [ik];
635664 const int force_nc = 3 ;
636665 // calculate the force
637- cal_force_nl_op<FPTYPE, Device>()(this ->ctx ,
638- nondiagonal,
639- npm,
640- this ->nbands ,
641- this ->ntype ,
642- current_spin,
643- this ->nlpp_ ->deeq .getBound2 (),
644- this ->nlpp_ ->deeq .getBound3 (),
645- this ->nlpp_ ->deeq .getBound4 (),
646- force_nc,
647- this ->nbands ,
648- ik,
649- nkb,
650- atom_nh,
651- atom_na,
652- this ->ucell_ ->tpiba ,
653- d_wg,
654- d_ekb,
655- qq_nt,
656- deeq,
657- becp,
658- dbecp,
659- force);
666+ if (this ->ucell_ ->get_npol () == 1 )
667+ {
668+ cal_force_nl_op<FPTYPE, Device>()(this ->ctx ,
669+ nondiagonal,
670+ npm,
671+ this ->nbands ,
672+ this ->ntype ,
673+ current_spin,
674+ this ->nlpp_ ->deeq .getBound2 (),
675+ this ->nlpp_ ->deeq .getBound3 (),
676+ this ->nlpp_ ->deeq .getBound4 (),
677+ force_nc,
678+ this ->nbands ,
679+ ik,
680+ nkb,
681+ atom_nh,
682+ atom_na,
683+ this ->ucell_ ->tpiba ,
684+ d_wg,
685+ d_ekb,
686+ qq_nt,
687+ deeq,
688+ becp,
689+ dbecp,
690+ force);
691+ }
692+ else
693+ {
694+ cal_force_nl_op<FPTYPE, Device>()(this ->ctx ,
695+ npm,
696+ this ->nbands ,
697+ this ->ntype ,
698+ this ->nlpp_ ->deeq_nc .getBound2 (),
699+ this ->nlpp_ ->deeq_nc .getBound3 (),
700+ this ->nlpp_ ->deeq_nc .getBound4 (),
701+ force_nc,
702+ this ->nbands ,
703+ ik,
704+ nkb,
705+ atom_nh,
706+ atom_na,
707+ this ->ucell_ ->tpiba ,
708+ d_wg,
709+ d_ekb,
710+ qq_nt,
711+ this ->nlpp_ ->template get_deeq_nc_data <FPTYPE>(),
712+ becp,
713+ dbecp,
714+ force);
715+ }
660716}
661717
662718// template instantiation
0 commit comments