Skip to content

Commit 63a7ab6

Browse files
zgn-26714dzzz2001
andauthored
cherry-pick #6392 to support GPU version of cal_force_cc under LCAO basis set (#6583)
* Perf: support GPU version of cal_force_cc with LCAO basis (#6392) * support GPU version of cal_force_cc with LCAO basis * fix a bug * Apply the changes related to force calculation from the develop branch, fixing the memory error bug that occurs during direct cherry-pick. * Reverted the process of removing PAW. * fix force calc --------- Co-authored-by: dzzz2001 <[email protected]>
1 parent be120f1 commit 63a7ab6

File tree

10 files changed

+581
-683
lines changed

10 files changed

+581
-683
lines changed

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/nonlocal_new.h"
2323

2424
template <typename T>
25-
Force_Stress_LCAO<T>::Force_Stress_LCAO(Record_adj& ra, const int nat_in) : RA(&ra), f_pw(nat_in), nat(nat_in)
25+
Force_Stress_LCAO<T>::Force_Stress_LCAO(Record_adj& ra, const int nat_in) : RA(&ra), nat(nat_in)
2626
{
2727
}
2828
template <typename T>
@@ -861,24 +861,39 @@ void Force_Stress_LCAO<T>::calForcePwPart(UnitCell& ucell,
861861
const Structure_Factor& sf)
862862
{
863863
ModuleBase::TITLE("Force_Stress_LCAO", "calForcePwPart");
864-
//--------------------------------------------------------
865-
// local pseudopotential force:
866-
// use charge density; plane wave; local pseudopotential;
867-
//--------------------------------------------------------
868-
f_pw.cal_force_loc(ucell, fvl_dvl, rhopw, locpp.vloc, chr);
869-
//--------------------------------------------------------
870-
// ewald force: use plane wave only.
871-
//--------------------------------------------------------
872-
f_pw.cal_force_ew(ucell, fewalds, rhopw, &sf); // remain problem
864+
#ifdef __CUDA
865+
if(PARAM.inp.device == "gpu")
866+
{
867+
Forces<double, base_device::DEVICE_GPU> f_pw(nat);
868+
869+
//--------------------------------------------------------
870+
// local pseudopotential force:
871+
// use charge density; plane wave; local pseudopotential;
872+
//--------------------------------------------------------
873+
f_pw.cal_force_loc(ucell, fvl_dvl, rhopw, locpp.vloc, chr);
874+
//--------------------------------------------------------
875+
// ewald force: use plane wave only.
876+
//--------------------------------------------------------
877+
f_pw.cal_force_ew(ucell, fewalds, rhopw, &sf); // remain problem
878+
879+
//--------------------------------------------------------
880+
// force due to core correlation.
881+
//--------------------------------------------------------
882+
f_pw.cal_force_cc(fcc, rhopw, chr, locpp.numeric, ucell);
883+
//--------------------------------------------------------
884+
// force due to self-consistent charge.
885+
//--------------------------------------------------------
886+
f_pw.cal_force_scc(fscc, rhopw, vnew, vnew_exist, locpp.numeric, ucell);
887+
} else
888+
#endif
889+
{
890+
Forces<double, base_device::DEVICE_CPU> f_pw(nat);
891+
f_pw.cal_force_loc(ucell, fvl_dvl, rhopw, locpp.vloc, chr);
892+
f_pw.cal_force_ew(ucell, fewalds, rhopw, &sf); // remain problem
893+
f_pw.cal_force_cc(fcc, rhopw, chr, locpp.numeric, ucell);
894+
f_pw.cal_force_scc(fscc, rhopw, vnew, vnew_exist, locpp.numeric, ucell);
895+
}
873896

874-
//--------------------------------------------------------
875-
// force due to core correlation.
876-
//--------------------------------------------------------
877-
f_pw.cal_force_cc(fcc, rhopw, chr, locpp.numeric, ucell);
878-
//--------------------------------------------------------
879-
// force due to self-consistent charge.
880-
//--------------------------------------------------------
881-
f_pw.cal_force_scc(fscc, rhopw, vnew, vnew_exist, locpp.numeric, ucell);
882897
return;
883898
}
884899

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ class Force_Stress_LCAO
6060
Record_adj* RA;
6161
Force_LCAO<T> flk;
6262
Stress_Func<double> sc_pw;
63-
Forces<double> f_pw;
6463

6564
void forceSymmetry(const UnitCell& ucell,
6665
ModuleBase::matrix& fcs,

0 commit comments

Comments
 (0)