Skip to content

Commit 2d5c9c4

Browse files
committed
Make init stochastic WF support GPU
1 parent d6a922b commit 2d5c9c4

21 files changed

+764
-628
lines changed

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 46 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@ void ESolver_SDFT_PW<T, Device>::before_all_runners(const Input_para& inp, UnitC
5454

5555
// 3) initialize the pointer for electronic states of SDFT
5656
this->pelec = new elecstate::ElecStatePW_SDFT<T, Device>(this->pw_wfc,
57-
&(this->chr),
58-
&(this->kv),
59-
&ucell,
60-
&(GlobalC::ppcell),
61-
this->pw_rhod,
62-
this->pw_rho,
63-
this->pw_big);
57+
&(this->chr),
58+
&(this->kv),
59+
&ucell,
60+
&(GlobalC::ppcell),
61+
this->pw_rhod,
62+
this->pw_rho,
63+
this->pw_big);
6464

6565
// 4) inititlize the charge density.
6666
this->pelec->charge->allocate(PARAM.inp.nspin);
@@ -80,11 +80,11 @@ void ESolver_SDFT_PW<T, Device>::before_all_runners(const Input_para& inp, UnitC
8080

8181
// 6) prepare some parameters for electronic wave functions initilization
8282
this->p_wf_init = new psi::WFInit<T, Device>(PARAM.inp.init_wfc,
83-
PARAM.inp.ks_solver,
84-
PARAM.inp.basis_type,
85-
PARAM.inp.psi_initializer,
86-
&this->wf,
87-
this->pw_wfc);
83+
PARAM.inp.ks_solver,
84+
PARAM.inp.basis_type,
85+
PARAM.inp.psi_initializer,
86+
&this->wf,
87+
this->pw_wfc);
8888
// 7) set occupatio, redundant?
8989
if (PARAM.inp.ocp)
9090
{
@@ -95,42 +95,38 @@ void ESolver_SDFT_PW<T, Device>::before_all_runners(const Input_para& inp, UnitC
9595
this->Init_GlobalC(inp, ucell, GlobalC::ppcell); // temporary
9696

9797
// 9) initialize the stochastic wave functions
98-
stowf.init(&this->kv, this->pw_wfc->npwk_max);
98+
this->stowf.init(&this->kv, this->pw_wfc->npwk_max);
9999
if (inp.nbands_sto != 0)
100100
{
101101
if (inp.initsto_ecut < inp.ecutwfc)
102102
{
103-
Init_Sto_Orbitals(this->stowf, inp.seed_sto);
103+
this->stowf.init_sto_orbitals(inp.seed_sto);
104104
}
105105
else
106106
{
107-
Init_Sto_Orbitals_Ecut(this->stowf, inp.seed_sto, this->kv, *this->pw_wfc, inp.initsto_ecut);
107+
this->stowf.init_sto_orbitals_Ecut(inp.seed_sto, this->kv, *this->pw_wfc, inp.initsto_ecut);
108108
}
109109
}
110110
else
111111
{
112-
Init_Com_Orbitals(this->stowf);
112+
this->stowf.init_com_orbitals();
113113
}
114114
if (this->method_sto == 2)
115115
{
116-
stowf.allocate_chiallorder(this->nche_sto);
116+
this->stowf.allocate_chiallorder(this->nche_sto);
117117
}
118+
this->stowf.sync_chi0();
118119

120+
// 10) allocate spaces for \sqrt(f(H))|chi> and |\tilde{chi}>
119121
size_t size = stowf.chi0->size();
120-
121-
this->stowf.shchi = new psi::Psi<T>(this->kv.get_nks(),
122-
this->stowf.nchip_max,
123-
this->wf.npwx,
124-
this->kv.ngk.data());
125-
122+
this->stowf.shchi
123+
= new psi::Psi<T, Device>(this->kv.get_nks(), this->stowf.nchip_max, this->wf.npwx, this->kv.ngk.data());
126124
ModuleBase::Memory::record("SDFT::shchi", size * sizeof(T));
127125

128126
if (PARAM.inp.nbands > 0)
129127
{
130-
this->stowf.chiortho = new psi::Psi<T>(this->kv.get_nks(),
131-
this->stowf.nchip_max,
132-
this->wf.npwx,
133-
this->kv.ngk.data());
128+
this->stowf.chiortho
129+
= new psi::Psi<T, Device>(this->kv.get_nks(), this->stowf.nchip_max, this->wf.npwx, this->kv.ngk.data());
134130
ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(T));
135131
}
136132

@@ -143,16 +139,16 @@ void ESolver_SDFT_PW<T, Device>::before_scf(const int istep)
143139
ESolver_KS_PW<T, Device>::before_scf(istep);
144140
delete reinterpret_cast<hamilt::HamiltPW<double>*>(this->p_hamilt);
145141
this->p_hamilt = new hamilt::HamiltSdftPW<T, Device>(this->pelec->pot,
146-
this->pw_wfc,
147-
&this->kv,
148-
PARAM.globalv.npol,
149-
&this->stoche.emin_sto,
150-
&this->stoche.emax_sto);
142+
this->pw_wfc,
143+
&this->kv,
144+
PARAM.globalv.npol,
145+
&this->stoche.emin_sto,
146+
&this->stoche.emax_sto);
151147
this->p_hamilt_sto = static_cast<hamilt::HamiltSdftPW<T, Device>*>(this->p_hamilt);
152148

153149
if (istep > 0 && PARAM.inp.nbands_sto != 0 && PARAM.inp.initsto_freq > 0 && istep % PARAM.inp.initsto_freq == 0)
154150
{
155-
Update_Sto_Orbitals(this->stowf, PARAM.inp.seed_sto);
151+
this->stowf.update_sto_orbitals(PARAM.inp.seed_sto);
156152
}
157153
}
158154

@@ -192,24 +188,23 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density(int istep, int iter, double ethr
192188
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
193189

194190
// hsolver only exists in this function
195-
hsolver::HSolverPW_SDFT<T, Device> hsolver_pw_sdft_obj(
196-
&this->kv,
197-
this->pw_wfc,
198-
&this->wf,
199-
this->stowf,
200-
this->stoche,
201-
this->p_hamilt_sto,
202-
PARAM.inp.calculation,
203-
PARAM.inp.basis_type,
204-
PARAM.inp.ks_solver,
205-
PARAM.inp.use_paw,
206-
PARAM.globalv.use_uspp,
207-
PARAM.inp.nspin,
208-
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
209-
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
210-
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
211-
hsolver::DiagoIterAssist<T, Device>::need_subspace,
212-
this->init_psi);
191+
hsolver::HSolverPW_SDFT<T, Device> hsolver_pw_sdft_obj(&this->kv,
192+
this->pw_wfc,
193+
&this->wf,
194+
this->stowf,
195+
this->stoche,
196+
this->p_hamilt_sto,
197+
PARAM.inp.calculation,
198+
PARAM.inp.basis_type,
199+
PARAM.inp.ks_solver,
200+
PARAM.inp.use_paw,
201+
PARAM.globalv.use_uspp,
202+
PARAM.inp.nspin,
203+
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
204+
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
205+
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
206+
hsolver::DiagoIterAssist<T, Device>::need_subspace,
207+
this->init_psi);
213208

214209
hsolver_pw_sdft_obj.solve(this->p_hamilt, this->psi[0], this->pelec, this->pw_wfc, this->stowf, istep, iter, false);
215210
this->init_psi = true;

source/module_esolver/esolver_sdft_pw.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
#define ESOLVER_SDFT_PW_H
33

44
#include "esolver_ks_pw.h"
5+
#include "module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.h"
6+
#include "module_hamilt_pw/hamilt_stodft/sto_che.h"
57
#include "module_hamilt_pw/hamilt_stodft/sto_iter.h"
68
#include "module_hamilt_pw/hamilt_stodft/sto_wf.h"
7-
#include "module_hamilt_pw/hamilt_stodft/sto_che.h"
8-
#include "module_hamilt_pw/hamilt_stodft/hamilt_sdft_pw.h"
99

1010
namespace ModuleESolver
1111
{
1212

13-
template <typename T, typename Device>
13+
template <typename T, typename Device = base_device::DEVICE_CPU>
1414
class ESolver_SDFT_PW : public ESolver_KS_PW<T, Device>
1515
{
1616
public:
@@ -26,7 +26,7 @@ class ESolver_SDFT_PW : public ESolver_KS_PW<T, Device>
2626
void cal_stress(ModuleBase::matrix& stress) override;
2727

2828
public:
29-
Stochastic_WF stowf;
29+
Stochastic_WF<T, Device> stowf;
3030
StoChe<double> stoche;
3131
hamilt::HamiltSdftPW<T, Device>* p_hamilt_sto = nullptr;
3232

source/module_hamilt_pw/hamilt_stodft/sto_dos.cpp

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#include "sto_dos.h"
22

3-
#include "module_parameter/parameter.h"
43
#include "module_base/timer.h"
54
#include "module_base/tool_title.h"
5+
#include "module_parameter/parameter.h"
66
#include "sto_tool.h"
77
Sto_DOS::~Sto_DOS()
88
{
@@ -14,7 +14,7 @@ Sto_DOS::Sto_DOS(ModulePW::PW_Basis_K* p_wfcpw_in,
1414
psi::Psi<std::complex<double>>* p_psi_in,
1515
hamilt::Hamilt<std::complex<double>>* p_hamilt_in,
1616
StoChe<double>& stoche,
17-
Stochastic_WF* p_stowf_in)
17+
Stochastic_WF<std::complex<double>, base_device::DEVICE_CPU>* p_stowf_in)
1818
{
1919
this->p_wfcpw = p_wfcpw_in;
2020
this->p_kv = p_kv_in;
@@ -38,13 +38,7 @@ void Sto_DOS::decide_param(const int& dos_nche,
3838
const double& dos_scale)
3939
{
4040
this->dos_nche = dos_nche;
41-
check_che(this->dos_nche,
42-
emin_sto,
43-
emax_sto,
44-
this->nbands_sto,
45-
this->p_kv,
46-
this->p_stowf,
47-
this->p_hamilt_sto);
41+
check_che(this->dos_nche, emin_sto, emax_sto, this->nbands_sto, this->p_kv, this->p_stowf, this->p_hamilt_sto);
4842
if (dos_setemax)
4943
{
5044
this->emax = dos_emax_ev;
@@ -147,12 +141,7 @@ void Sto_DOS::caldos(const double sigmain, const double de, const int npart)
147141
}
148142
ModuleBase::GlobalFunc::ZEROS(allorderchi.data(), nchipk_new * npwx * dos_nche);
149143
std::complex<double>* tmpchi = pchi + start_nchipk * npwx;
150-
che.calpolyvec_complex(hchi_norm,
151-
tmpchi,
152-
allorderchi.data(),
153-
npw,
154-
npwx,
155-
nchipk_new);
144+
che.calpolyvec_complex(hchi_norm, tmpchi, allorderchi.data(), npw, npwx, nchipk_new);
156145
double* vec_all = (double*)allorderchi.data();
157146
int LDA = npwx * nchipk_new * 2;
158147
int M = npwx * nchipk_new * 2;

source/module_hamilt_pw/hamilt_stodft/sto_dos.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class Sto_DOS
1515
psi::Psi<std::complex<double>>* p_psi_in,
1616
hamilt::Hamilt<std::complex<double>>* p_hamilt_in,
1717
StoChe<double>& stoche,
18-
Stochastic_WF* p_stowf_in);
18+
Stochastic_WF<std::complex<double>, base_device::DEVICE_CPU>* p_stowf_in);
1919
~Sto_DOS();
2020

2121
/**
@@ -59,8 +59,10 @@ class Sto_DOS
5959
elecstate::ElecState* p_elec = nullptr; ///< pointer to the electronic state
6060
psi::Psi<std::complex<double>>* p_psi = nullptr; ///< pointer to the wavefunction
6161
hamilt::Hamilt<std::complex<double>>* p_hamilt; ///< pointer to the Hamiltonian
62-
Stochastic_WF* p_stowf = nullptr; ///< pointer to the stochastic wavefunctions
63-
Sto_Func<double> stofunc; ///< functions
62+
63+
Stochastic_WF<std::complex<double>, base_device::DEVICE_CPU>* p_stowf
64+
= nullptr; ///< pointer to the stochastic wavefunctions
65+
Sto_Func<double> stofunc; ///< functions
6466

6567
hamilt::HamiltSdftPW<std::complex<double>>* p_hamilt_sto = nullptr; ///< pointer to the Hamiltonian for sDFT
6668
};

source/module_hamilt_pw/hamilt_stodft/sto_elecond.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#include "sto_elecond.h"
22

3-
#include "module_parameter/parameter.h"
43
#include "module_base/complexmatrix.h"
54
#include "module_base/constants.h"
65
#include "module_base/memory.h"
76
#include "module_base/timer.h"
87
#include "module_base/vector3.h"
8+
#include "module_parameter/parameter.h"
99
#include "sto_tool.h"
1010

1111
#include <chrono>
@@ -21,7 +21,7 @@ Sto_EleCond::Sto_EleCond(UnitCell* p_ucell_in,
2121
pseudopot_cell_vnl* p_ppcell_in,
2222
hamilt::Hamilt<std::complex<double>>* p_hamilt_in,
2323
StoChe<double>& stoche,
24-
Stochastic_WF* p_stowf_in)
24+
Stochastic_WF<std::complex<double>, base_device::DEVICE_CPU>* p_stowf_in)
2525
: EleCond(p_ucell_in, p_kv_in, p_elec_in, p_wfcpw_in, p_psi_in, p_ppcell_in)
2626
{
2727
this->p_hamilt = p_hamilt_in;
@@ -386,9 +386,10 @@ void Sto_EleCond::cal_jmatrix(const psi::Psi<std::complex<float>>& kspsi_all,
386386

387387
remain -= tmpnb;
388388
startnb += tmpnb;
389-
if (remain == 0) {
389+
if (remain == 0)
390+
{
390391
break;
391-
}
392+
}
392393
}
393394

394395
for (int id = 0; id < ndim; ++id)

source/module_hamilt_pw/hamilt_stodft/sto_elecond.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Sto_EleCond : protected EleCond
1717
pseudopot_cell_vnl* p_ppcell_in,
1818
hamilt::Hamilt<std::complex<double>>* p_hamilt_in,
1919
StoChe<double>& stoche,
20-
Stochastic_WF* p_stowf_in);
20+
Stochastic_WF<std::complex<double>, base_device::DEVICE_CPU>* p_stowf_in);
2121
~Sto_EleCond(){};
2222
/**
2323
* @brief Set the N order of Chebyshev expansion for conductivities
@@ -59,8 +59,9 @@ class Sto_EleCond : protected EleCond
5959
int fd_nche = 0; ///< number of Chebyshev orders for Fermi-Dirac function
6060
int cond_dtbatch = 0; ///< number of time steps in a batch
6161
hamilt::Hamilt<std::complex<double>>* p_hamilt; ///< pointer to the Hamiltonian
62-
Stochastic_WF* p_stowf = nullptr; ///< pointer to the stochastic wavefunctions
63-
Sto_Func<double> stofunc; ///< functions
62+
Stochastic_WF<std::complex<double>, base_device::DEVICE_CPU>* p_stowf
63+
= nullptr; ///< pointer to the stochastic wavefunctions
64+
Sto_Func<double> stofunc; ///< functions
6465

6566
hamilt::HamiltSdftPW<std::complex<double>>* p_hamilt_sto = nullptr; ///< pointer to the Hamiltonian for sDFT
6667

0 commit comments

Comments
 (0)