diff --git a/source/module_elecstate/potentials/potential_new.cpp b/source/module_elecstate/potentials/potential_new.cpp index e90d080fd4..48b90f65b4 100644 --- a/source/module_elecstate/potentials/potential_new.cpp +++ b/source/module_elecstate/potentials/potential_new.cpp @@ -46,7 +46,7 @@ Potential::~Potential() } this->components.clear(); } - if (PARAM.inp.device == "gpu") { + if (PARAM.inp.basis_type == "pw" && PARAM.inp.device == "gpu") { if (PARAM.inp.precision == "single") { delmem_sd_op()(gpu_ctx, s_veff_smooth); delmem_sd_op()(gpu_ctx, s_vofk_smooth); @@ -129,7 +129,7 @@ void Potential::allocate() this->vofk_smooth.create(PARAM.inp.nspin, nrxx_smooth); ModuleBase::Memory::record("Pot::vofk_smooth", sizeof(double) * PARAM.inp.nspin * nrxx_smooth); } - if (PARAM.inp.device == "gpu") { + if (PARAM.inp.basis_type == "pw" && PARAM.inp.device == "gpu") { if (PARAM.inp.precision == "single") { resmem_sd_op()(gpu_ctx, s_veff_smooth, PARAM.inp.nspin * nrxx_smooth); resmem_sd_op()(gpu_ctx, s_vofk_smooth, PARAM.inp.nspin * nrxx_smooth); @@ -177,7 +177,7 @@ void Potential::update_from_charge(const Charge*const chg, const UnitCell*const } #endif - if (PARAM.inp.device == "gpu") { + if (PARAM.inp.basis_type == "pw" && PARAM.inp.device == "gpu") { if (PARAM.inp.precision == "single") { castmem_d2s_h2d_op()(gpu_ctx, cpu_ctx, diff --git a/source/module_esolver/esolver_fp.cpp b/source/module_esolver/esolver_fp.cpp index 7bb5c83f26..4b87337235 100644 --- a/source/module_esolver/esolver_fp.cpp +++ b/source/module_esolver/esolver_fp.cpp @@ -17,11 +17,16 @@ namespace ModuleESolver ESolver_FP::ESolver_FP() { // pw_rho = new ModuleBase::PW_Basis(); - pw_rho = new ModulePW::PW_Basis_Big(PARAM.inp.device, PARAM.inp.precision); - + // LCAO basis doesn't support GPU acceleration on FFT currently + std::string fft_device = PARAM.inp.device; + if(PARAM.inp.basis_type == "lcao") + { + fft_device = "cpu"; + } + pw_rho = new ModulePW::PW_Basis_Big(fft_device, PARAM.inp.precision); if ( PARAM.globalv.double_grid) { - pw_rhod = new ModulePW::PW_Basis_Big(PARAM.inp.device, PARAM.inp.precision); + pw_rhod = new ModulePW::PW_Basis_Big(fft_device, PARAM.inp.precision); } else { diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index 78163a9301..f57ff84d53 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -57,7 +57,13 @@ ESolver_KS::ESolver_KS() // pw_rho = new ModuleBase::PW_Basis(); // temporary, it will be removed - pw_wfc = new ModulePW::PW_Basis_K_Big(PARAM.inp.device, PARAM.inp.precision); + std::string fft_device = PARAM.inp.device; + // LCAO basis doesn't support GPU acceleration on FFT currently + if(PARAM.inp.basis_type == "lcao") + { + fft_device = "cpu"; + } + pw_wfc = new ModulePW::PW_Basis_K_Big(fft_device, PARAM.inp.precision); ModulePW::PW_Basis_K_Big* tmp = static_cast(pw_wfc); // should not use INPUT here, mohan 2024-05-12 diff --git a/source/module_hamilt_pw/hamilt_pwdft/structure_factor.cpp b/source/module_hamilt_pw/hamilt_pwdft/structure_factor.cpp index 163d6ca210..dd3663cbd0 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/structure_factor.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/structure_factor.cpp @@ -16,12 +16,16 @@ Structure_Factor::Structure_Factor() { - + // LCAO basis doesn't support GPU acceleration on this function currently. + if(PARAM.inp.basis_type == "pw") + { + this->device = PARAM.inp.device; + } } Structure_Factor::~Structure_Factor() { - if (PARAM.inp.device == "gpu") { + if (device == "gpu") { if (PARAM.inp.precision == "single") { delmem_cd_op()(gpu_ctx, this->c_eigts1); delmem_cd_op()(gpu_ctx, this->c_eigts2); @@ -145,7 +149,7 @@ void Structure_Factor::setup_structure_factor(UnitCell* Ucell, const ModulePW::P inat++; } } - if (PARAM.inp.device == "gpu") { + if (device == "gpu") { if (PARAM.inp.precision == "single") { resmem_cd_op()(gpu_ctx, this->c_eigts1, Ucell->nat * (2 * rho_basis->nx + 1)); resmem_cd_op()(gpu_ctx, this->c_eigts2, Ucell->nat * (2 * rho_basis->ny + 1)); diff --git a/source/module_hamilt_pw/hamilt_pwdft/structure_factor.h b/source/module_hamilt_pw/hamilt_pwdft/structure_factor.h index 23b08c253e..5fc4e8a129 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/structure_factor.h +++ b/source/module_hamilt_pw/hamilt_pwdft/structure_factor.h @@ -56,5 +56,6 @@ class Structure_Factor std::complex * c_eigts1 = nullptr, * c_eigts2 = nullptr, * c_eigts3 = nullptr; std::complex * z_eigts1 = nullptr, * z_eigts2 = nullptr, * z_eigts3 = nullptr; const ModulePW::PW_Basis* rho_basis = nullptr; + std::string device = "cpu"; }; #endif //PlaneWave class