Skip to content

Commit cfffc78

Browse files
authored
remove unnecessary gpu memory allocation under LCAO basis (#5258)
1 parent 7938206 commit cfffc78

File tree

5 files changed

+26
-10
lines changed

5 files changed

+26
-10
lines changed

source/module_elecstate/potentials/potential_new.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ Potential::~Potential()
4646
}
4747
this->components.clear();
4848
}
49-
if (PARAM.inp.device == "gpu") {
49+
if (PARAM.inp.basis_type == "pw" && PARAM.inp.device == "gpu") {
5050
if (PARAM.inp.precision == "single") {
5151
delmem_sd_op()(gpu_ctx, s_veff_smooth);
5252
delmem_sd_op()(gpu_ctx, s_vofk_smooth);
@@ -129,7 +129,7 @@ void Potential::allocate()
129129
this->vofk_smooth.create(PARAM.inp.nspin, nrxx_smooth);
130130
ModuleBase::Memory::record("Pot::vofk_smooth", sizeof(double) * PARAM.inp.nspin * nrxx_smooth);
131131
}
132-
if (PARAM.inp.device == "gpu") {
132+
if (PARAM.inp.basis_type == "pw" && PARAM.inp.device == "gpu") {
133133
if (PARAM.inp.precision == "single") {
134134
resmem_sd_op()(gpu_ctx, s_veff_smooth, PARAM.inp.nspin * nrxx_smooth);
135135
resmem_sd_op()(gpu_ctx, s_vofk_smooth, PARAM.inp.nspin * nrxx_smooth);
@@ -177,7 +177,7 @@ void Potential::update_from_charge(const Charge*const chg, const UnitCell*const
177177
}
178178
#endif
179179

180-
if (PARAM.inp.device == "gpu") {
180+
if (PARAM.inp.basis_type == "pw" && PARAM.inp.device == "gpu") {
181181
if (PARAM.inp.precision == "single") {
182182
castmem_d2s_h2d_op()(gpu_ctx,
183183
cpu_ctx,

source/module_esolver/esolver_fp.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,16 @@ namespace ModuleESolver
1717
ESolver_FP::ESolver_FP()
1818
{
1919
// pw_rho = new ModuleBase::PW_Basis();
20-
pw_rho = new ModulePW::PW_Basis_Big(PARAM.inp.device, PARAM.inp.precision);
21-
20+
// LCAO basis doesn't support GPU acceleration on FFT currently
21+
std::string fft_device = PARAM.inp.device;
22+
if(PARAM.inp.basis_type == "lcao")
23+
{
24+
fft_device = "cpu";
25+
}
26+
pw_rho = new ModulePW::PW_Basis_Big(fft_device, PARAM.inp.precision);
2227
if ( PARAM.globalv.double_grid)
2328
{
24-
pw_rhod = new ModulePW::PW_Basis_Big(PARAM.inp.device, PARAM.inp.precision);
29+
pw_rhod = new ModulePW::PW_Basis_Big(fft_device, PARAM.inp.precision);
2530
}
2631
else
2732
{

source/module_esolver/esolver_ks.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,13 @@ ESolver_KS<T, Device>::ESolver_KS()
5757

5858
// pw_rho = new ModuleBase::PW_Basis();
5959
// temporary, it will be removed
60-
pw_wfc = new ModulePW::PW_Basis_K_Big(PARAM.inp.device, PARAM.inp.precision);
60+
std::string fft_device = PARAM.inp.device;
61+
// LCAO basis doesn't support GPU acceleration on FFT currently
62+
if(PARAM.inp.basis_type == "lcao")
63+
{
64+
fft_device = "cpu";
65+
}
66+
pw_wfc = new ModulePW::PW_Basis_K_Big(fft_device, PARAM.inp.precision);
6167
ModulePW::PW_Basis_K_Big* tmp = static_cast<ModulePW::PW_Basis_K_Big*>(pw_wfc);
6268

6369
// should not use INPUT here, mohan 2024-05-12

source/module_hamilt_pw/hamilt_pwdft/structure_factor.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,16 @@
1616

1717
Structure_Factor::Structure_Factor()
1818
{
19-
19+
// LCAO basis doesn't support GPU acceleration on this function currently.
20+
if(PARAM.inp.basis_type == "pw")
21+
{
22+
this->device = PARAM.inp.device;
23+
}
2024
}
2125

2226
Structure_Factor::~Structure_Factor()
2327
{
24-
if (PARAM.inp.device == "gpu") {
28+
if (device == "gpu") {
2529
if (PARAM.inp.precision == "single") {
2630
delmem_cd_op()(gpu_ctx, this->c_eigts1);
2731
delmem_cd_op()(gpu_ctx, this->c_eigts2);
@@ -145,7 +149,7 @@ void Structure_Factor::setup_structure_factor(UnitCell* Ucell, const ModulePW::P
145149
inat++;
146150
}
147151
}
148-
if (PARAM.inp.device == "gpu") {
152+
if (device == "gpu") {
149153
if (PARAM.inp.precision == "single") {
150154
resmem_cd_op()(gpu_ctx, this->c_eigts1, Ucell->nat * (2 * rho_basis->nx + 1));
151155
resmem_cd_op()(gpu_ctx, this->c_eigts2, Ucell->nat * (2 * rho_basis->ny + 1));

source/module_hamilt_pw/hamilt_pwdft/structure_factor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,6 @@ class Structure_Factor
5656
std::complex<float> * c_eigts1 = nullptr, * c_eigts2 = nullptr, * c_eigts3 = nullptr;
5757
std::complex<double> * z_eigts1 = nullptr, * z_eigts2 = nullptr, * z_eigts3 = nullptr;
5858
const ModulePW::PW_Basis* rho_basis = nullptr;
59+
std::string device = "cpu";
5960
};
6061
#endif //PlaneWave class

0 commit comments

Comments
 (0)