@@ -108,12 +108,6 @@ void FS_Nonlocal_tools<FPTYPE, Device>::allocate_memory(const ModuleBase::matrix
108108 this ->atom_na = h_atom_na.data ();
109109 this ->ppcell_vkb = this ->nlpp_ ->vkb .c ;
110110 }
111-
112- // prepare the memory of the becp and dbecp:
113- // becp: <Beta(nkb,npw)|psi(nbnd,npw)>
114- // dbecp: <dBeta(nkb,npw)/dG|psi(nbnd,npw)>
115- resmem_complex_op ()(this ->ctx , becp, this ->nbands * nkb, " Stress::becp" );
116- resmem_complex_op ()(this ->ctx , dbecp, 6 * this ->nbands * nkb, " Stress::dbecp" );
117111}
118112
119113template <typename FPTYPE, typename Device>
@@ -163,9 +157,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
163157{
164158 ModuleBase::TITLE (" FS_Nonlocal_tools" , " cal_becp" );
165159 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_becp" );
160+ const int npol = this ->ucell_ ->get_npol ();
161+ const int size_becp = this ->nbands * npol * this ->nkb ;
162+ const int size_becp_act = npm * npol * this ->nkb ;
166163 if (this ->becp == nullptr )
167164 {
168- resmem_complex_op ()(this ->ctx , becp, this -> nbands * this -> nkb );
165+ resmem_complex_op ()(this ->ctx , becp, size_becp );
169166 }
170167
171168 // prepare math tools
@@ -249,11 +246,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
249246 }
250247 const char transa = ' C' ;
251248 const char transb = ' N' ;
249+ int npm_npol = npm * npol;
252250 gemm_op ()(this ->ctx ,
253251 transa,
254252 transb,
255253 nkb,
256- npm ,
254+ npm_npol ,
257255 npw,
258256 &ModuleBase::ONE,
259257 ppcell_vkb,
@@ -268,15 +266,15 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(int ik, int npm)
268266 if (this ->device == base_device::GpuDevice)
269267 {
270268 std::complex <FPTYPE>* h_becp = nullptr ;
271- resmem_complex_h_op ()(this ->cpu_ctx , h_becp, this -> nbands * nkb );
272- syncmem_complex_d2h_op ()(this ->cpu_ctx , this ->ctx , h_becp, becp, this -> nbands * nkb );
273- Parallel_Reduce::reduce_pool (h_becp, this -> nbands * nkb );
274- syncmem_complex_h2d_op ()(this ->ctx , this ->cpu_ctx , becp, h_becp, this -> nbands * nkb );
269+ resmem_complex_h_op ()(this ->cpu_ctx , h_becp, size_becp_act );
270+ syncmem_complex_d2h_op ()(this ->cpu_ctx , this ->ctx , h_becp, becp, size_becp_act );
271+ Parallel_Reduce::reduce_pool (h_becp, size_becp_act );
272+ syncmem_complex_h2d_op ()(this ->ctx , this ->cpu_ctx , becp, h_becp, size_becp_act );
275273 delmem_complex_h_op ()(this ->cpu_ctx , h_becp);
276274 }
277275 else
278276 {
279- Parallel_Reduce::reduce_pool (becp, this -> nbands * this -> nkb );
277+ Parallel_Reduce::reduce_pool (becp, size_becp_act );
280278 }
281279 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_becp" );
282280}
@@ -287,9 +285,12 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
287285{
288286 ModuleBase::TITLE (" FS_Nonlocal_tools" , " cal_dbecp_s" );
289287 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_dbecp_s" );
288+ const int npol = this ->ucell_ ->get_npol ();
289+ const int size_becp = this ->nbands * npol * this ->nkb ;
290+ const int npm_npol = npm * npol;
290291 if (this ->dbecp == nullptr )
291292 {
292- resmem_complex_op ()(this ->ctx , dbecp, this -> nbands * this -> nkb );
293+ resmem_complex_op ()(this ->ctx , dbecp, size_becp );
293294 }
294295
295296 // prepare math tools
@@ -401,7 +402,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
401402 transa,
402403 transb,
403404 nkb,
404- npm ,
405+ npm_npol ,
405406 npw,
406407 &ModuleBase::ONE,
407408 ppcell_vkb,
@@ -412,43 +413,68 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(int ik, int npm, int ipol, i
412413 dbecp,
413414 nkb);
414415 // calculate stress for target (ipol, jpol)
415- const int current_spin = this ->kv_ ->isk [ik];
416- cal_stress_nl_op ()(this ->ctx ,
417- nondiagonal,
418- ipol,
419- jpol,
420- nkb,
421- npm,
422- this ->ntype ,
423- current_spin, // uspp only
424- this ->nbands ,
425- ik,
426- this ->nlpp_ ->deeq .getBound2 (),
427- this ->nlpp_ ->deeq .getBound3 (),
428- this ->nlpp_ ->deeq .getBound4 (),
429- atom_nh,
430- atom_na,
431- d_wg,
432- d_ekb,
433- qq_nt,
434- deeq,
435- becp,
436- dbecp,
437- stress);
416+ if (npol == 1 )
417+ {
418+ const int current_spin = this ->kv_ ->isk [ik];
419+ cal_stress_nl_op ()(this ->ctx ,
420+ nondiagonal,
421+ ipol,
422+ jpol,
423+ nkb,
424+ npm,
425+ this ->ntype ,
426+ current_spin, // uspp only
427+ this ->nlpp_ ->deeq .getBound2 (),
428+ this ->nlpp_ ->deeq .getBound3 (),
429+ this ->nlpp_ ->deeq .getBound4 (),
430+ atom_nh,
431+ atom_na,
432+ d_wg + this ->nbands * ik,
433+ d_ekb + this ->nbands * ik,
434+ qq_nt,
435+ deeq,
436+ becp,
437+ dbecp,
438+ stress);
439+ }
440+ else
441+ {
442+ cal_stress_nl_op ()(this ->ctx ,
443+ ipol,
444+ jpol,
445+ nkb,
446+ npm,
447+ this ->ntype ,
448+ this ->nlpp_ ->deeq_nc .getBound2 (),
449+ this ->nlpp_ ->deeq_nc .getBound3 (),
450+ this ->nlpp_ ->deeq_nc .getBound4 (),
451+ atom_nh,
452+ atom_na,
453+ d_wg + this ->nbands * ik,
454+ d_ekb + this ->nbands * ik,
455+ qq_nt,
456+ this ->nlpp_ ->template get_deeq_nc_data <FPTYPE>(),
457+ becp,
458+ dbecp,
459+ stress);
460+ }
438461 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_dbecp_s" );
439462}
440463
441464template <typename FPTYPE, typename Device>
442465void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
443466{
444- ModuleBase::TITLE (" FS_Nonlocal_tools" , " cal_dbecp_s " );
467+ ModuleBase::TITLE (" FS_Nonlocal_tools" , " cal_dbecp_f " );
445468 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_dbecp_f" );
469+ const int npol = this ->ucell_ ->get_npol ();
470+ const int npm_npol = npm * npol;
471+ const int size_becp = this ->nbands * npol * this ->nkb ;
446472 if (this ->dbecp == nullptr )
447473 {
448- resmem_complex_op ()(this ->ctx , dbecp, 3 * this -> nbands * this -> nkb );
474+ resmem_complex_op ()(this ->ctx , dbecp, 3 * size_becp );
449475 }
450476
451- std::complex <FPTYPE>* dbecp_ptr = this ->dbecp + ipol * this -> nbands * this -> nkb ;
477+ std::complex <FPTYPE>* dbecp_ptr = this ->dbecp + ipol * size_becp ;
452478 const std::complex <FPTYPE>* vkb_ptr = this ->ppcell_vkb ;
453479 std::complex <FPTYPE>* vkb_deri_ptr = this ->ppcell_vkb ;
454480
@@ -481,7 +507,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
481507 transa,
482508 transb,
483509 this ->nkb ,
484- npm ,
510+ npm_npol ,
485511 npw,
486512 &ModuleBase::ONE,
487513 vkb_deri_ptr,
@@ -491,7 +517,6 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(int ik, int npm, int ipol)
491517 &ModuleBase::ZERO,
492518 dbecp_ptr,
493519 nkb);
494-
495520 this ->revert_vkb (npw, ipol);
496521 this ->pre_ik_f = ik;
497522 ModuleBase::timer::tick (" FS_Nonlocal_tools" , " cal_dbecp_f" );
@@ -634,29 +659,52 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_force(int ik, int npm, FPTYPE* force
634659 const int current_spin = this ->kv_ ->isk [ik];
635660 const int force_nc = 3 ;
636661 // calculate the force
637- cal_force_nl_op<FPTYPE, Device>()(this ->ctx ,
638- nondiagonal,
639- npm,
640- this ->nbands ,
641- this ->ntype ,
642- current_spin,
643- this ->nlpp_ ->deeq .getBound2 (),
644- this ->nlpp_ ->deeq .getBound3 (),
645- this ->nlpp_ ->deeq .getBound4 (),
646- force_nc,
647- this ->nbands ,
648- ik,
649- nkb,
650- atom_nh,
651- atom_na,
652- this ->ucell_ ->tpiba ,
653- d_wg,
654- d_ekb,
655- qq_nt,
656- deeq,
657- becp,
658- dbecp,
659- force);
662+ if (this ->ucell_ ->get_npol () == 1 )
663+ {
664+ cal_force_nl_op<FPTYPE, Device>()(this ->ctx ,
665+ nondiagonal,
666+ npm,
667+ this ->ntype ,
668+ current_spin,
669+ this ->nlpp_ ->deeq .getBound2 (),
670+ this ->nlpp_ ->deeq .getBound3 (),
671+ this ->nlpp_ ->deeq .getBound4 (),
672+ force_nc,
673+ this ->nbands ,
674+ nkb,
675+ atom_nh,
676+ atom_na,
677+ this ->ucell_ ->tpiba ,
678+ d_wg + this ->nbands * ik,
679+ d_ekb + this ->nbands * ik,
680+ qq_nt,
681+ deeq,
682+ becp,
683+ dbecp,
684+ force);
685+ }
686+ else
687+ {
688+ cal_force_nl_op<FPTYPE, Device>()(this ->ctx ,
689+ npm,
690+ this ->ntype ,
691+ this ->nlpp_ ->deeq_nc .getBound2 (),
692+ this ->nlpp_ ->deeq_nc .getBound3 (),
693+ this ->nlpp_ ->deeq_nc .getBound4 (),
694+ force_nc,
695+ this ->nbands ,
696+ nkb,
697+ atom_nh,
698+ atom_na,
699+ this ->ucell_ ->tpiba ,
700+ d_wg + this ->nbands * ik,
701+ d_ekb + this ->nbands * ik,
702+ qq_nt,
703+ this ->nlpp_ ->template get_deeq_nc_data <FPTYPE>(),
704+ becp,
705+ dbecp,
706+ force);
707+ }
660708}
661709
662710// template instantiation
0 commit comments