Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions source/module_elecstate/potentials/potential_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Potential::~Potential()
}
this->components.clear();
}
if (PARAM.inp.device == "gpu") {
if (PARAM.inp.basis_type == "pw" && PARAM.inp.device == "gpu") {
if (PARAM.inp.precision == "single") {
delmem_sd_op()(gpu_ctx, s_veff_smooth);
delmem_sd_op()(gpu_ctx, s_vofk_smooth);
Expand Down Expand Up @@ -129,7 +129,7 @@ void Potential::allocate()
this->vofk_smooth.create(PARAM.inp.nspin, nrxx_smooth);
ModuleBase::Memory::record("Pot::vofk_smooth", sizeof(double) * PARAM.inp.nspin * nrxx_smooth);
}
if (PARAM.inp.device == "gpu") {
if (PARAM.inp.basis_type == "pw" && PARAM.inp.device == "gpu") {
if (PARAM.inp.precision == "single") {
resmem_sd_op()(gpu_ctx, s_veff_smooth, PARAM.inp.nspin * nrxx_smooth);
resmem_sd_op()(gpu_ctx, s_vofk_smooth, PARAM.inp.nspin * nrxx_smooth);
Expand Down Expand Up @@ -177,7 +177,7 @@ void Potential::update_from_charge(const Charge*const chg, const UnitCell*const
}
#endif

if (PARAM.inp.device == "gpu") {
if (PARAM.inp.basis_type == "pw" && PARAM.inp.device == "gpu") {
if (PARAM.inp.precision == "single") {
castmem_d2s_h2d_op()(gpu_ctx,
cpu_ctx,
Expand Down
11 changes: 8 additions & 3 deletions source/module_esolver/esolver_fp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@ namespace ModuleESolver
ESolver_FP::ESolver_FP()
{
// pw_rho = new ModuleBase::PW_Basis();
pw_rho = new ModulePW::PW_Basis_Big(PARAM.inp.device, PARAM.inp.precision);

// LCAO basis doesn't support GPU acceleration on FFT currently
std::string fft_device = PARAM.inp.device;
if(PARAM.inp.basis_type == "lcao")
{
fft_device = "cpu";
}
pw_rho = new ModulePW::PW_Basis_Big(fft_device, PARAM.inp.precision);
if ( PARAM.globalv.double_grid)
{
pw_rhod = new ModulePW::PW_Basis_Big(PARAM.inp.device, PARAM.inp.precision);
pw_rhod = new ModulePW::PW_Basis_Big(fft_device, PARAM.inp.precision);
}
else
{
Expand Down
8 changes: 7 additions & 1 deletion source/module_esolver/esolver_ks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ ESolver_KS<T, Device>::ESolver_KS()

// pw_rho = new ModuleBase::PW_Basis();
// temporary, it will be removed
pw_wfc = new ModulePW::PW_Basis_K_Big(PARAM.inp.device, PARAM.inp.precision);
std::string fft_device = PARAM.inp.device;
// LCAO basis doesn't support GPU acceleration on FFT currently
if(PARAM.inp.basis_type == "lcao")
{
fft_device = "cpu";
}
pw_wfc = new ModulePW::PW_Basis_K_Big(fft_device, PARAM.inp.precision);
ModulePW::PW_Basis_K_Big* tmp = static_cast<ModulePW::PW_Basis_K_Big*>(pw_wfc);

// should not use INPUT here, mohan 2024-05-12
Expand Down
10 changes: 7 additions & 3 deletions source/module_hamilt_pw/hamilt_pwdft/structure_factor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

/// Default constructor: decides which device string this object records.
///
/// The member `device` is overwritten with PARAM.inp.device only when the
/// input basis_type is exactly "pw"; for every other basis type the
/// constructor leaves `device` untouched.
Structure_Factor::Structure_Factor()
{
    // LCAO basis doesn't support GPU acceleration on this function currently,
    // so only a plane-wave run adopts the user-selected device.
    const bool plane_wave_basis = (PARAM.inp.basis_type == "pw");
    if (!plane_wave_basis)
    {
        return;
    }

    this->device = PARAM.inp.device;
}

Structure_Factor::~Structure_Factor()
{
if (PARAM.inp.device == "gpu") {
if (device == "gpu") {
if (PARAM.inp.precision == "single") {
delmem_cd_op()(gpu_ctx, this->c_eigts1);
delmem_cd_op()(gpu_ctx, this->c_eigts2);
Expand Down Expand Up @@ -145,7 +149,7 @@ void Structure_Factor::setup_structure_factor(UnitCell* Ucell, const ModulePW::P
inat++;
}
}
if (PARAM.inp.device == "gpu") {
if (device == "gpu") {
if (PARAM.inp.precision == "single") {
resmem_cd_op()(gpu_ctx, this->c_eigts1, Ucell->nat * (2 * rho_basis->nx + 1));
resmem_cd_op()(gpu_ctx, this->c_eigts2, Ucell->nat * (2 * rho_basis->ny + 1));
Expand Down
1 change: 1 addition & 0 deletions source/module_hamilt_pw/hamilt_pwdft/structure_factor.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,6 @@ class Structure_Factor
std::complex<float> * c_eigts1 = nullptr, * c_eigts2 = nullptr, * c_eigts3 = nullptr;
std::complex<double> * z_eigts1 = nullptr, * z_eigts2 = nullptr, * z_eigts3 = nullptr;
const ModulePW::PW_Basis* rho_basis = nullptr;
std::string device = "cpu";
};
#endif //PlaneWave class
Loading