Skip to content

Commit f3b5cc4

Browse files
committed
add template for precision
1 parent 79fae46 commit f3b5cc4

File tree

8 files changed

+100
-89
lines changed

8 files changed

+100
-89
lines changed

source/module_elecstate/elecstate_pw_sdft.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
namespace elecstate
99
{
1010

11-
template <typename Device>
12-
void ElecStatePW_SDFT<Device>::psiToRho(const psi::Psi<std::complex<double>>& psi)
11+
template <typename T, typename Device>
12+
void ElecStatePW_SDFT<T, Device>::psiToRho(const psi::Psi<T>& psi)
1313
{
1414
ModuleBase::TITLE(this->classname, "psiToRho");
1515
ModuleBase::timer::tick(this->classname, "psiToRho");
@@ -42,5 +42,6 @@ void ElecStatePW_SDFT<Device>::psiToRho(const psi::Psi<std::complex<double>>& ps
4242
return;
4343
}
4444

45-
template class ElecStatePW_SDFT<base_device::DEVICE_CPU>;
45+
// template class ElecStatePW_SDFT<std::complex<float>, base_device::DEVICE_CPU>;
46+
template class ElecStatePW_SDFT<std::complex<double>, base_device::DEVICE_CPU>;
4647
} // namespace elecstate

source/module_elecstate/elecstate_pw_sdft.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
#include "elecstate_pw.h"
44
namespace elecstate
55
{
6-
template <typename Device>
7-
class ElecStatePW_SDFT : public ElecStatePW<std::complex<double>, Device>
6+
template <typename T, typename Device>
7+
class ElecStatePW_SDFT : public ElecStatePW<T, Device>
88
{
99
public:
1010
ElecStatePW_SDFT(ModulePW::PW_Basis_K* wfc_basis_in,
@@ -15,12 +15,12 @@ class ElecStatePW_SDFT : public ElecStatePW<std::complex<double>, Device>
1515
ModulePW::PW_Basis* rhodpw_in,
1616
ModulePW::PW_Basis* rhopw_in,
1717
ModulePW::PW_Basis_Big* bigpw_in)
18-
: ElecStatePW<std::complex<double>,
18+
: ElecStatePW<T,
1919
Device>(wfc_basis_in, chg_in, pkv_in, ucell_in, ppcell_in, rhodpw_in, rhopw_in, bigpw_in)
2020
{
2121
this->classname = "ElecStatePW_SDFT";
2222
}
23-
virtual void psiToRho(const psi::Psi<std::complex<double>>& psi) override;
23+
virtual void psiToRho(const psi::Psi<T>& psi) override;
2424
};
2525
} // namespace elecstate
2626
#endif

source/module_esolver/esolver.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,14 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
155155
}
156156
else if (esolver_type == "sdft_pw")
157157
{
158-
return new ESolver_SDFT_PW<base_device::DEVICE_CPU>();
158+
// if (PARAM.inp.precision == "single")
159+
// {
160+
// return new ESolver_SDFT_PW<std::complex<float>, base_device::DEVICE_CPU>();
161+
// }
162+
// else
163+
// {
164+
return new ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>();
165+
// }
159166
}
160167
#ifdef __LCAO
161168
else if (esolver_type == "ksdft_lip")

source/module_esolver/esolver_sdft_pw.cpp

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -29,31 +29,31 @@
2929
namespace ModuleESolver
3030
{
3131

32-
template <typename Device>
33-
ESolver_SDFT_PW<Device>::ESolver_SDFT_PW()
32+
template <typename T, typename Device>
33+
ESolver_SDFT_PW<T, Device>::ESolver_SDFT_PW()
3434
: stoche(PARAM.inp.nche_sto, PARAM.inp.method_sto, PARAM.inp.emax_sto, PARAM.inp.emin_sto)
3535
{
3636
this->classname = "ESolver_SDFT_PW";
3737
this->basisname = "PW";
3838
}
3939

40-
template <typename Device>
41-
ESolver_SDFT_PW<Device>::~ESolver_SDFT_PW()
40+
template <typename T, typename Device>
41+
ESolver_SDFT_PW<T, Device>::~ESolver_SDFT_PW()
4242
{
4343
}
4444

45-
template <typename Device>
46-
void ESolver_SDFT_PW<Device>::before_all_runners(const Input_para& inp, UnitCell& ucell)
45+
template <typename T, typename Device>
46+
void ESolver_SDFT_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCell& ucell)
4747
{
4848
// 1) initialize parameters from int Input class
4949
this->nche_sto = inp.nche_sto;
5050
this->method_sto = inp.method_sto;
5151

5252
// 2) run "before_all_runners" in ESolver_KS
53-
ESolver_KS<std::complex<double>, Device>::before_all_runners(inp, ucell);
53+
ESolver_KS<T, Device>::before_all_runners(inp, ucell);
5454

5555
// 3) initialize the pointer for electronic states of SDFT
56-
this->pelec = new elecstate::ElecStatePW_SDFT<Device>(this->pw_wfc,
56+
this->pelec = new elecstate::ElecStatePW_SDFT<T, Device>(this->pw_wfc,
5757
&(this->chr),
5858
&(this->kv),
5959
&ucell,
@@ -79,7 +79,7 @@ void ESolver_SDFT_PW<Device>::before_all_runners(const Input_para& inp, UnitCell
7979
}
8080

8181
// 6) prepare some parameters for electronic wave functions initilization
82-
this->p_wf_init = new psi::WFInit<std::complex<double>, Device>(PARAM.inp.init_wfc,
82+
this->p_wf_init = new psi::WFInit<T, Device>(PARAM.inp.init_wfc,
8383
PARAM.inp.ks_solver,
8484
PARAM.inp.basis_type,
8585
PARAM.inp.psi_initializer,
@@ -118,60 +118,60 @@ void ESolver_SDFT_PW<Device>::before_all_runners(const Input_para& inp, UnitCell
118118

119119
size_t size = stowf.chi0->size();
120120

121-
this->stowf.shchi = new psi::Psi<std::complex<double>>(this->kv.get_nks(),
121+
this->stowf.shchi = new psi::Psi<T>(this->kv.get_nks(),
122122
this->stowf.nchip_max,
123123
this->wf.npwx,
124124
this->kv.ngk.data());
125125

126-
ModuleBase::Memory::record("SDFT::shchi", size * sizeof(std::complex<double>));
126+
ModuleBase::Memory::record("SDFT::shchi", size * sizeof(T));
127127

128128
if (PARAM.inp.nbands > 0)
129129
{
130-
this->stowf.chiortho = new psi::Psi<std::complex<double>>(this->kv.get_nks(),
130+
this->stowf.chiortho = new psi::Psi<T>(this->kv.get_nks(),
131131
this->stowf.nchip_max,
132132
this->wf.npwx,
133133
this->kv.ngk.data());
134-
ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(std::complex<double>));
134+
ModuleBase::Memory::record("SDFT::chiortho", size * sizeof(T));
135135
}
136136

137137
return;
138138
}
139139

140-
template <typename Device>
141-
void ESolver_SDFT_PW<Device>::before_scf(const int istep)
140+
template <typename T, typename Device>
141+
void ESolver_SDFT_PW<T, Device>::before_scf(const int istep)
142142
{
143-
ESolver_KS_PW<std::complex<double>, Device>::before_scf(istep);
143+
ESolver_KS_PW<T, Device>::before_scf(istep);
144144
delete reinterpret_cast<hamilt::HamiltPW<double>*>(this->p_hamilt);
145-
this->p_hamilt = new hamilt::HamiltSdftPW<std::complex<double>, Device>(this->pelec->pot,
145+
this->p_hamilt = new hamilt::HamiltSdftPW<T, Device>(this->pelec->pot,
146146
this->pw_wfc,
147147
&this->kv,
148148
PARAM.globalv.npol,
149149
&this->stoche.emin_sto,
150150
&this->stoche.emax_sto);
151-
this->p_hamilt_sto = static_cast<hamilt::HamiltSdftPW<std::complex<double>, Device>*>(this->p_hamilt);
151+
this->p_hamilt_sto = static_cast<hamilt::HamiltSdftPW<T, Device>*>(this->p_hamilt);
152152

153153
if (istep > 0 && PARAM.inp.nbands_sto != 0 && PARAM.inp.initsto_freq > 0 && istep % PARAM.inp.initsto_freq == 0)
154154
{
155155
Update_Sto_Orbitals(this->stowf, PARAM.inp.seed_sto);
156156
}
157157
}
158158

159-
template <typename Device>
160-
void ESolver_SDFT_PW<Device>::iter_finish(int& iter)
159+
template <typename T, typename Device>
160+
void ESolver_SDFT_PW<T, Device>::iter_finish(int& iter)
161161
{
162162
// call iter_finish() of ESolver_KS
163-
ESolver_KS<std::complex<double>, Device>::iter_finish(iter);
163+
ESolver_KS<T, Device>::iter_finish(iter);
164164
}
165165

166-
template <typename Device>
167-
void ESolver_SDFT_PW<Device>::after_scf(const int istep)
166+
template <typename T, typename Device>
167+
void ESolver_SDFT_PW<T, Device>::after_scf(const int istep)
168168
{
169169
// 1) call after_scf() of ESolver_KS_PW
170-
ESolver_KS_PW<std::complex<double>, Device>::after_scf(istep);
170+
ESolver_KS_PW<T, Device>::after_scf(istep);
171171
}
172172

173-
template <typename Device>
174-
void ESolver_SDFT_PW<Device>::hamilt2density(int istep, int iter, double ethr)
173+
template <typename T, typename Device>
174+
void ESolver_SDFT_PW<T, Device>::hamilt2density(int istep, int iter, double ethr)
175175
{
176176
// reset energy
177177
this->pelec->f_en.eband = 0.0;
@@ -180,19 +180,19 @@ void ESolver_SDFT_PW<Device>::hamilt2density(int istep, int iter, double ethr)
180180
// be careful that istep start from 0 and iter start from 1
181181
if (istep == 0 && iter == 1)
182182
{
183-
hsolver::DiagoIterAssist<std::complex<double>, Device>::need_subspace = false;
183+
hsolver::DiagoIterAssist<T, Device>::need_subspace = false;
184184
}
185185
else
186186
{
187-
hsolver::DiagoIterAssist<std::complex<double>, Device>::need_subspace = true;
187+
hsolver::DiagoIterAssist<T, Device>::need_subspace = true;
188188
}
189189

190-
hsolver::DiagoIterAssist<std::complex<double>, Device>::PW_DIAG_THR = ethr;
190+
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR = ethr;
191191

192-
hsolver::DiagoIterAssist<std::complex<double>, Device>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
192+
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX = PARAM.inp.pw_diag_nmax;
193193

194194
// hsolver only exists in this function
195-
hsolver::HSolverPW_SDFT<Device> hsolver_pw_sdft_obj(
195+
hsolver::HSolverPW_SDFT<T, Device> hsolver_pw_sdft_obj(
196196
&this->kv,
197197
this->pw_wfc,
198198
&this->wf,
@@ -205,10 +205,10 @@ void ESolver_SDFT_PW<Device>::hamilt2density(int istep, int iter, double ethr)
205205
PARAM.inp.use_paw,
206206
PARAM.globalv.use_uspp,
207207
PARAM.inp.nspin,
208-
hsolver::DiagoIterAssist<std::complex<double>, Device>::SCF_ITER,
209-
hsolver::DiagoIterAssist<std::complex<double>, Device>::PW_DIAG_NMAX,
210-
hsolver::DiagoIterAssist<std::complex<double>, Device>::PW_DIAG_THR,
211-
hsolver::DiagoIterAssist<std::complex<double>, Device>::need_subspace,
208+
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
209+
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
210+
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
211+
hsolver::DiagoIterAssist<T, Device>::need_subspace,
212212
this->init_psi);
213213

214214
hsolver_pw_sdft_obj.solve(this->p_hamilt, this->psi[0], this->pelec, this->pw_wfc, this->stowf, istep, iter, false);
@@ -240,14 +240,14 @@ void ESolver_SDFT_PW<Device>::hamilt2density(int istep, int iter, double ethr)
240240
#endif
241241
}
242242

243-
template <typename Device>
244-
double ESolver_SDFT_PW<Device>::cal_energy()
243+
template <typename T, typename Device>
244+
double ESolver_SDFT_PW<T, Device>::cal_energy()
245245
{
246246
return this->pelec->f_en.etot;
247247
}
248248

249-
template <typename Device>
250-
void ESolver_SDFT_PW<Device>::cal_force(ModuleBase::matrix& force)
249+
template <typename T, typename Device>
250+
void ESolver_SDFT_PW<T, Device>::cal_force(ModuleBase::matrix& force)
251251
{
252252
Sto_Forces ff(GlobalC::ucell.nat);
253253

@@ -262,8 +262,8 @@ void ESolver_SDFT_PW<Device>::cal_force(ModuleBase::matrix& force)
262262
this->stowf);
263263
}
264264

265-
template <typename Device>
266-
void ESolver_SDFT_PW<Device>::cal_stress(ModuleBase::matrix& stress)
265+
template <typename T, typename Device>
266+
void ESolver_SDFT_PW<T, Device>::cal_stress(ModuleBase::matrix& stress)
267267
{
268268
Sto_Stress_PW ss;
269269
ss.cal_stress(stress,
@@ -280,8 +280,8 @@ void ESolver_SDFT_PW<Device>::cal_stress(ModuleBase::matrix& stress)
280280
GlobalC::ucell);
281281
}
282282

283-
template <typename Device>
284-
void ESolver_SDFT_PW<Device>::after_all_runners()
283+
template <typename T, typename Device>
284+
void ESolver_SDFT_PW<T, Device>::after_all_runners()
285285
{
286286
GlobalV::ofs_running << "\n\n --------------------------------------------" << std::endl;
287287
GlobalV::ofs_running << std::setprecision(16);
@@ -291,7 +291,7 @@ void ESolver_SDFT_PW<Device>::after_all_runners()
291291
}
292292

293293
template <>
294-
void ESolver_SDFT_PW<base_device::DEVICE_CPU>::after_all_runners()
294+
void ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>::after_all_runners()
295295
{
296296

297297
GlobalV::ofs_running << "\n\n --------------------------------------------" << std::endl;
@@ -341,8 +341,8 @@ void ESolver_SDFT_PW<base_device::DEVICE_CPU>::after_all_runners()
341341
}
342342
}
343343

344-
template <typename Device>
345-
void ESolver_SDFT_PW<Device>::others(const int istep)
344+
template <typename T, typename Device>
345+
void ESolver_SDFT_PW<T, Device>::others(const int istep)
346346
{
347347
ModuleBase::TITLE("ESolver_SDFT_PW", "others");
348348

@@ -352,14 +352,14 @@ void ESolver_SDFT_PW<Device>::others(const int istep)
352352
}
353353
else
354354
{
355-
ModuleBase::WARNING_QUIT("ESolver_SDFT_PW<Device>::others", "CALCULATION type not supported");
355+
ModuleBase::WARNING_QUIT("ESolver_SDFT_PW<T, Device>::others", "CALCULATION type not supported");
356356
}
357357

358358
return;
359359
}
360360

361-
template <typename Device>
362-
void ESolver_SDFT_PW<Device>::nscf()
361+
template <typename T, typename Device>
362+
void ESolver_SDFT_PW<T, Device>::nscf()
363363
{
364364
ModuleBase::TITLE("ESolver_SDFT_PW", "nscf");
365365
ModuleBase::timer::tick("ESolver_SDFT_PW", "nscf");
@@ -382,5 +382,6 @@ void ESolver_SDFT_PW<Device>::nscf()
382382
return;
383383
}
384384

385-
template class ESolver_SDFT_PW<base_device::DEVICE_CPU>;
385+
// template class ESolver_SDFT_PW<std::complex<float>, base_device::DEVICE_CPU>;
386+
template class ESolver_SDFT_PW<std::complex<double>, base_device::DEVICE_CPU>;
386387
} // namespace ModuleESolver

source/module_esolver/esolver_sdft_pw.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
namespace ModuleESolver
1111
{
1212

13-
template <typename Device>
14-
class ESolver_SDFT_PW : public ESolver_KS_PW<std::complex<double>, Device>
13+
template <typename T, typename Device>
14+
class ESolver_SDFT_PW : public ESolver_KS_PW<T, Device>
1515
{
1616
public:
1717
ESolver_SDFT_PW();
@@ -28,7 +28,7 @@ class ESolver_SDFT_PW : public ESolver_KS_PW<std::complex<double>, Device>
2828
public:
2929
Stochastic_WF stowf;
3030
StoChe<double> stoche;
31-
hamilt::HamiltSdftPW<std::complex<double>, Device>* p_hamilt_sto = nullptr;
31+
hamilt::HamiltSdftPW<T, Device>* p_hamilt_sto = nullptr;
3232

3333
protected:
3434
virtual void before_scf(const int istep) override;

source/module_hsolver/hsolver_pw_sdft.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
namespace hsolver
1111
{
12-
template <typename Device>
13-
void HSolverPW_SDFT<Device>::solve(hamilt::Hamilt<std::complex<double>, Device>* pHamilt,
14-
psi::Psi<std::complex<double>, Device>& psi,
12+
template <typename T, typename Device>
13+
void HSolverPW_SDFT<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
14+
psi::Psi<T, Device>& psi,
1515
elecstate::ElecState* pes,
1616
ModulePW::PW_Basis_K* wfc_basis,
1717
Stochastic_WF& stowf,
@@ -107,5 +107,6 @@ void HSolverPW_SDFT<Device>::solve(hamilt::Hamilt<std::complex<double>, Device>*
107107
return;
108108
}
109109

110-
template class HSolverPW_SDFT<base_device::DEVICE_CPU>;
110+
// template class HSolverPW_SDFT<std::complex<float>, base_device::DEVICE_CPU>;
111+
template class HSolverPW_SDFT<std::complex<double>, base_device::DEVICE_CPU>;
111112
} // namespace hsolver

source/module_hsolver/hsolver_pw_sdft.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
#include "module_hamilt_pw/hamilt_stodft/sto_iter.h"
66
namespace hsolver
77
{
8-
template <typename Device>
9-
class HSolverPW_SDFT : public HSolverPW<std::complex<double>, Device>
8+
template <typename T, typename Device>
9+
class HSolverPW_SDFT : public HSolverPW<T, Device>
1010
{
1111
public:
1212
HSolverPW_SDFT(K_Vectors* pkv,
1313
ModulePW::PW_Basis_K* wfc_basis_in,
1414
wavefunc* pwf_in,
1515
Stochastic_WF& stowf,
1616
StoChe<double>& stoche,
17-
hamilt::HamiltSdftPW<std::complex<double>, Device>* p_hamilt_sto,
17+
hamilt::HamiltSdftPW<T, Device>* p_hamilt_sto,
1818
const std::string calculation_type_in,
1919
const std::string basis_type_in,
2020
const std::string method_in,
@@ -26,7 +26,7 @@ class HSolverPW_SDFT : public HSolverPW<std::complex<double>, Device>
2626
const double diag_thr_in,
2727
const bool need_subspace_in,
2828
const bool initialed_psi_in)
29-
: HSolverPW<std::complex<double>, Device>(wfc_basis_in,
29+
: HSolverPW<T, Device>(wfc_basis_in,
3030
pwf_in,
3131
calculation_type_in,
3232
basis_type_in,
@@ -43,8 +43,8 @@ class HSolverPW_SDFT : public HSolverPW<std::complex<double>, Device>
4343
stoiter.init(pkv, wfc_basis_in, stowf, stoche, p_hamilt_sto);
4444
}
4545

46-
void solve(hamilt::Hamilt<std::complex<double>, Device>* pHamilt,
47-
psi::Psi<std::complex<double>, Device>& psi,
46+
void solve(hamilt::Hamilt<T, Device>* pHamilt,
47+
psi::Psi<T, Device>& psi,
4848
elecstate::ElecState* pes,
4949
ModulePW::PW_Basis_K* wfc_basis,
5050
Stochastic_WF& stowf,

0 commit comments

Comments
 (0)