2929namespace ModuleESolver
3030{
3131
32- template <typename Device>
33- ESolver_SDFT_PW<Device>::ESolver_SDFT_PW()
32+ template <typename T, typename Device>
33+ ESolver_SDFT_PW<T, Device>::ESolver_SDFT_PW()
3434 : stoche(PARAM.inp.nche_sto, PARAM.inp.method_sto, PARAM.inp.emax_sto, PARAM.inp.emin_sto)
3535{
3636 this ->classname = " ESolver_SDFT_PW" ;
3737 this ->basisname = " PW" ;
3838}
3939
40- template <typename Device>
41- ESolver_SDFT_PW<Device>::~ESolver_SDFT_PW ()
40+ template <typename T, typename Device>
41+ ESolver_SDFT_PW<T, Device>::~ESolver_SDFT_PW ()
4242{
4343}
4444
45- template <typename Device>
46- void ESolver_SDFT_PW<Device>::before_all_runners(const Input_para& inp, UnitCell& ucell)
45+ template <typename T, typename Device>
46+ void ESolver_SDFT_PW<T, Device>::before_all_runners(const Input_para& inp, UnitCell& ucell)
4747{
4848 // 1) initialize parameters from int Input class
4949 this ->nche_sto = inp.nche_sto ;
5050 this ->method_sto = inp.method_sto ;
5151
5252 // 2) run "before_all_runners" in ESolver_KS
53- ESolver_KS<std:: complex < double > , Device>::before_all_runners (inp, ucell);
53+ ESolver_KS<T , Device>::before_all_runners (inp, ucell);
5454
5555 // 3) initialize the pointer for electronic states of SDFT
56- this ->pelec = new elecstate::ElecStatePW_SDFT<Device>(this ->pw_wfc ,
56+ this ->pelec = new elecstate::ElecStatePW_SDFT<T, Device>(this ->pw_wfc ,
5757 &(this ->chr ),
5858 &(this ->kv ),
5959 &ucell,
@@ -79,7 +79,7 @@ void ESolver_SDFT_PW<Device>::before_all_runners(const Input_para& inp, UnitCell
7979 }
8080
8181 // 6) prepare some parameters for electronic wave functions initilization
82- this ->p_wf_init = new psi::WFInit<std:: complex < double > , Device>(PARAM.inp .init_wfc ,
82+ this ->p_wf_init = new psi::WFInit<T , Device>(PARAM.inp .init_wfc ,
8383 PARAM.inp .ks_solver ,
8484 PARAM.inp .basis_type ,
8585 PARAM.inp .psi_initializer ,
@@ -118,60 +118,60 @@ void ESolver_SDFT_PW<Device>::before_all_runners(const Input_para& inp, UnitCell
118118
119119 size_t size = stowf.chi0 ->size ();
120120
121- this ->stowf .shchi = new psi::Psi<std:: complex < double > >(this ->kv .get_nks (),
121+ this ->stowf .shchi = new psi::Psi<T >(this ->kv .get_nks (),
122122 this ->stowf .nchip_max ,
123123 this ->wf .npwx ,
124124 this ->kv .ngk .data ());
125125
126- ModuleBase::Memory::record (" SDFT::shchi" , size * sizeof (std:: complex < double > ));
126+ ModuleBase::Memory::record (" SDFT::shchi" , size * sizeof (T ));
127127
128128 if (PARAM.inp .nbands > 0 )
129129 {
130- this ->stowf .chiortho = new psi::Psi<std:: complex < double > >(this ->kv .get_nks (),
130+ this ->stowf .chiortho = new psi::Psi<T >(this ->kv .get_nks (),
131131 this ->stowf .nchip_max ,
132132 this ->wf .npwx ,
133133 this ->kv .ngk .data ());
134- ModuleBase::Memory::record (" SDFT::chiortho" , size * sizeof (std:: complex < double > ));
134+ ModuleBase::Memory::record (" SDFT::chiortho" , size * sizeof (T ));
135135 }
136136
137137 return ;
138138}
139139
140- template <typename Device>
141- void ESolver_SDFT_PW<Device>::before_scf(const int istep)
140+ template <typename T, typename Device>
141+ void ESolver_SDFT_PW<T, Device>::before_scf(const int istep)
142142{
143- ESolver_KS_PW<std:: complex < double > , Device>::before_scf (istep);
143+ ESolver_KS_PW<T , Device>::before_scf (istep);
144144 delete reinterpret_cast <hamilt::HamiltPW<double >*>(this ->p_hamilt );
145- this ->p_hamilt = new hamilt::HamiltSdftPW<std:: complex < double > , Device>(this ->pelec ->pot ,
145+ this ->p_hamilt = new hamilt::HamiltSdftPW<T , Device>(this ->pelec ->pot ,
146146 this ->pw_wfc ,
147147 &this ->kv ,
148148 PARAM.globalv .npol ,
149149 &this ->stoche .emin_sto ,
150150 &this ->stoche .emax_sto );
151- this ->p_hamilt_sto = static_cast <hamilt::HamiltSdftPW<std:: complex < double > , Device>*>(this ->p_hamilt );
151+ this ->p_hamilt_sto = static_cast <hamilt::HamiltSdftPW<T , Device>*>(this ->p_hamilt );
152152
153153 if (istep > 0 && PARAM.inp .nbands_sto != 0 && PARAM.inp .initsto_freq > 0 && istep % PARAM.inp .initsto_freq == 0 )
154154 {
155155 Update_Sto_Orbitals (this ->stowf , PARAM.inp .seed_sto );
156156 }
157157}
158158
159- template <typename Device>
160- void ESolver_SDFT_PW<Device>::iter_finish(int & iter)
159+ template <typename T, typename Device>
160+ void ESolver_SDFT_PW<T, Device>::iter_finish(int & iter)
161161{
162162 // call iter_finish() of ESolver_KS
163- ESolver_KS<std:: complex < double > , Device>::iter_finish (iter);
163+ ESolver_KS<T , Device>::iter_finish (iter);
164164}
165165
166- template <typename Device>
167- void ESolver_SDFT_PW<Device>::after_scf(const int istep)
166+ template <typename T, typename Device>
167+ void ESolver_SDFT_PW<T, Device>::after_scf(const int istep)
168168{
169169 // 1) call after_scf() of ESolver_KS_PW
170- ESolver_KS_PW<std:: complex < double > , Device>::after_scf (istep);
170+ ESolver_KS_PW<T , Device>::after_scf (istep);
171171}
172172
173- template <typename Device>
174- void ESolver_SDFT_PW<Device>::hamilt2density(int istep, int iter, double ethr)
173+ template <typename T, typename Device>
174+ void ESolver_SDFT_PW<T, Device>::hamilt2density(int istep, int iter, double ethr)
175175{
176176 // reset energy
177177 this ->pelec ->f_en .eband = 0.0 ;
@@ -180,19 +180,19 @@ void ESolver_SDFT_PW<Device>::hamilt2density(int istep, int iter, double ethr)
180180 // be careful that istep start from 0 and iter start from 1
181181 if (istep == 0 && iter == 1 )
182182 {
183- hsolver::DiagoIterAssist<std:: complex < double > , Device>::need_subspace = false ;
183+ hsolver::DiagoIterAssist<T , Device>::need_subspace = false ;
184184 }
185185 else
186186 {
187- hsolver::DiagoIterAssist<std:: complex < double > , Device>::need_subspace = true ;
187+ hsolver::DiagoIterAssist<T , Device>::need_subspace = true ;
188188 }
189189
190- hsolver::DiagoIterAssist<std:: complex < double > , Device>::PW_DIAG_THR = ethr;
190+ hsolver::DiagoIterAssist<T , Device>::PW_DIAG_THR = ethr;
191191
192- hsolver::DiagoIterAssist<std:: complex < double > , Device>::PW_DIAG_NMAX = PARAM.inp .pw_diag_nmax ;
192+ hsolver::DiagoIterAssist<T , Device>::PW_DIAG_NMAX = PARAM.inp .pw_diag_nmax ;
193193
194194 // hsolver only exists in this function
195- hsolver::HSolverPW_SDFT<Device> hsolver_pw_sdft_obj (
195+ hsolver::HSolverPW_SDFT<T, Device> hsolver_pw_sdft_obj (
196196 &this ->kv ,
197197 this ->pw_wfc ,
198198 &this ->wf ,
@@ -205,10 +205,10 @@ void ESolver_SDFT_PW<Device>::hamilt2density(int istep, int iter, double ethr)
205205 PARAM.inp .use_paw ,
206206 PARAM.globalv .use_uspp ,
207207 PARAM.inp .nspin ,
208- hsolver::DiagoIterAssist<std:: complex < double > , Device>::SCF_ITER,
209- hsolver::DiagoIterAssist<std:: complex < double > , Device>::PW_DIAG_NMAX,
210- hsolver::DiagoIterAssist<std:: complex < double > , Device>::PW_DIAG_THR,
211- hsolver::DiagoIterAssist<std:: complex < double > , Device>::need_subspace,
208+ hsolver::DiagoIterAssist<T , Device>::SCF_ITER,
209+ hsolver::DiagoIterAssist<T , Device>::PW_DIAG_NMAX,
210+ hsolver::DiagoIterAssist<T , Device>::PW_DIAG_THR,
211+ hsolver::DiagoIterAssist<T , Device>::need_subspace,
212212 this ->init_psi );
213213
214214 hsolver_pw_sdft_obj.solve (this ->p_hamilt , this ->psi [0 ], this ->pelec , this ->pw_wfc , this ->stowf , istep, iter, false );
@@ -240,14 +240,14 @@ void ESolver_SDFT_PW<Device>::hamilt2density(int istep, int iter, double ethr)
240240#endif
241241}
242242
243- template <typename Device>
244- double ESolver_SDFT_PW<Device>::cal_energy()
243+ template <typename T, typename Device>
244+ double ESolver_SDFT_PW<T, Device>::cal_energy()
245245{
246246 return this ->pelec ->f_en .etot ;
247247}
248248
249- template <typename Device>
250- void ESolver_SDFT_PW<Device>::cal_force(ModuleBase::matrix& force)
249+ template <typename T, typename Device>
250+ void ESolver_SDFT_PW<T, Device>::cal_force(ModuleBase::matrix& force)
251251{
252252 Sto_Forces ff (GlobalC::ucell.nat );
253253
@@ -262,8 +262,8 @@ void ESolver_SDFT_PW<Device>::cal_force(ModuleBase::matrix& force)
262262 this ->stowf );
263263}
264264
265- template <typename Device>
266- void ESolver_SDFT_PW<Device>::cal_stress(ModuleBase::matrix& stress)
265+ template <typename T, typename Device>
266+ void ESolver_SDFT_PW<T, Device>::cal_stress(ModuleBase::matrix& stress)
267267{
268268 Sto_Stress_PW ss;
269269 ss.cal_stress (stress,
@@ -280,8 +280,8 @@ void ESolver_SDFT_PW<Device>::cal_stress(ModuleBase::matrix& stress)
280280 GlobalC::ucell);
281281}
282282
283- template <typename Device>
284- void ESolver_SDFT_PW<Device>::after_all_runners()
283+ template <typename T, typename Device>
284+ void ESolver_SDFT_PW<T, Device>::after_all_runners()
285285{
286286 GlobalV::ofs_running << " \n\n --------------------------------------------" << std::endl;
287287 GlobalV::ofs_running << std::setprecision (16 );
@@ -291,7 +291,7 @@ void ESolver_SDFT_PW<Device>::after_all_runners()
291291}
292292
293293template <>
294- void ESolver_SDFT_PW<base_device::DEVICE_CPU>::after_all_runners()
294+ void ESolver_SDFT_PW<std:: complex < double >, base_device::DEVICE_CPU>::after_all_runners()
295295{
296296
297297 GlobalV::ofs_running << " \n\n --------------------------------------------" << std::endl;
@@ -341,8 +341,8 @@ void ESolver_SDFT_PW<base_device::DEVICE_CPU>::after_all_runners()
341341 }
342342}
343343
344- template <typename Device>
345- void ESolver_SDFT_PW<Device>::others(const int istep)
344+ template <typename T, typename Device>
345+ void ESolver_SDFT_PW<T, Device>::others(const int istep)
346346{
347347 ModuleBase::TITLE (" ESolver_SDFT_PW" , " others" );
348348
@@ -352,14 +352,14 @@ void ESolver_SDFT_PW<Device>::others(const int istep)
352352 }
353353 else
354354 {
355- ModuleBase::WARNING_QUIT (" ESolver_SDFT_PW<Device>::others" , " CALCULATION type not supported" );
355+ ModuleBase::WARNING_QUIT (" ESolver_SDFT_PW<T, Device>::others" , " CALCULATION type not supported" );
356356 }
357357
358358 return ;
359359}
360360
361- template <typename Device>
362- void ESolver_SDFT_PW<Device>::nscf()
361+ template <typename T, typename Device>
362+ void ESolver_SDFT_PW<T, Device>::nscf()
363363{
364364 ModuleBase::TITLE (" ESolver_SDFT_PW" , " nscf" );
365365 ModuleBase::timer::tick (" ESolver_SDFT_PW" , " nscf" );
@@ -382,5 +382,6 @@ void ESolver_SDFT_PW<Device>::nscf()
382382 return ;
383383}
384384
385- template class ESolver_SDFT_PW <base_device::DEVICE_CPU>;
385+ // template class ESolver_SDFT_PW<std::complex<float>, base_device::DEVICE_CPU>;
386+ template class ESolver_SDFT_PW <std::complex <double >, base_device::DEVICE_CPU>;
386387} // namespace ModuleESolver
0 commit comments