@@ -186,176 +186,28 @@ int wavefunc::get_starting_nw() const
186186namespace hamilt
187187{
188188
189- void diago_PAO_in_pw_k2 (const int & ik,
190- psi::Psi<std::complex <float >>& wvf,
189+ template <>
190+ void diago_PAO_in_pw_k2 (const base_device::DEVICE_CPU* ctx,
191+ const int & ik,
192+ psi::Psi<std::complex <float >, base_device::DEVICE_CPU>& wvf,
191193 ModulePW::PW_Basis_K* wfc_basis,
192194 wavefunc* p_wf,
193195 const ModuleBase::realArray& tab_at,
194196 const int & lmaxkb,
195- hamilt::Hamilt<std::complex <float >>* phm_in)
197+ hamilt::Hamilt<std::complex <float >, base_device::DEVICE_CPU >* phm_in)
196198{
197- ModuleBase::TITLE (" wavefunc" , " diago_PAO_in_pw_k2" );
198-
199- const int nbasis = wvf.get_nbasis ();
200- const int nbands = wvf.get_nbands ();
201- const int current_nbasis = wfc_basis->npwk [ik];
202-
203- if (PARAM.inp .init_wfc == " file" )
204- {
205- ModuleBase::ComplexMatrix wfcatom (nbands, nbasis);
206- std::stringstream filename;
207- int ik_tot = K_Vectors::get_ik_global (ik, p_wf->nkstot );
208- filename << PARAM.globalv .global_readin_dir << " WAVEFUNC" << ik_tot + 1 << " .dat" ;
209- ModuleIO::read_wfc_pw (filename.str (), wfc_basis, ik, p_wf->nkstot , wfcatom);
210-
211- std::vector<std::complex <float >> s_wfcatom (nbands * nbasis);
212- castmem_z2c_h2h_op ()(cpu_ctx, cpu_ctx, s_wfcatom.data (), wfcatom.c , nbands * nbasis);
213-
214- if (PARAM.inp .ks_solver == " cg" )
215- {
216- std::vector<float > etfile (nbands, 0.0 );
217- if (phm_in != nullptr )
218- {
219- hsolver::DiagoIterAssist<std::complex <float >>::diagH_subspace_init (phm_in,
220- s_wfcatom.data (),
221- wfcatom.nr ,
222- wfcatom.nc ,
223- wvf,
224- etfile.data ());
225- return ;
226- }
227- else
228- {
229- ModuleBase::WARNING_QUIT (" wavefunc" , " Psi does not exist!" );
230- }
231- }
232-
233- assert (nbands <= wfcatom.nr );
234- for (int ib = 0 ; ib < nbands; ib++)
235- {
236- for (int ig = 0 ; ig < nbasis; ig++)
237- {
238- wvf (ib, ig) = s_wfcatom[ib * nbasis + ig];
239- }
240- }
241- return ;
242- }
243-
244- const int starting_nw = p_wf->get_starting_nw ();
245- if (starting_nw == 0 )
246- {
247- return ;
248- }
249- assert (starting_nw > 0 );
250- std::vector<float > etatom (starting_nw, 0.0 );
251-
252- // special case here! use Psi(k-1) for the initialization of Psi(k)
253- // this method should be tested.
254- /* if(PARAM.inp.calculation == "nscf" && GlobalC::ucell.natomwfc == 0 && ik>0)
255- {
256- //this is memsaver case
257- if(wvf.get_nk() == 1)
258- {
259- return;
260- }
261- else
262- {
263- ModuleBase::GlobalFunc::COPYARRAY(&wvf(ik-1, 0, 0), &wvf(ik, 0, 0), wvf.get_nbasis()* wvf.get_nbands());
264- return;
265- }
266- }
267- */
268-
269- if (PARAM.inp .init_wfc == " random" || (PARAM.inp .init_wfc .substr (0 , 6 ) == " atomic" && GlobalC::ucell.natomwfc == 0 ))
270- {
271- p_wf->random (wvf.get_pointer (), 0 , nbands, ik, wfc_basis);
272-
273- if (PARAM.inp .ks_solver == " cg" ) // xiaohui add 2013-09-02
274- {
275- if (phm_in != nullptr )
276- {
277- hsolver::DiagoIterAssist<std::complex <float >>::diagH_subspace (phm_in, wvf, wvf, etatom.data ());
278- return ;
279- }
280- else
281- {
282- ModuleBase::WARNING_QUIT (" wavefunc" , " Hamiltonian does not exist!" );
283- }
284- }
285- }
286- else if (PARAM.inp .init_wfc .substr (0 , 6 ) == " atomic" )
287- {
288- ModuleBase::ComplexMatrix wfcatom (starting_nw, nbasis); // added by zhengdy-soc
289- if (PARAM.inp .test_wf )
290- {
291- ModuleBase::GlobalFunc::OUT (GlobalV::ofs_running, " starting_nw" , starting_nw);
292- }
293-
294- p_wf->atomic_wfc (ik,
295- current_nbasis,
296- GlobalC::ucell.lmax_ppwf ,
297- lmaxkb,
298- wfc_basis,
299- wfcatom,
300- tab_at,
301- PARAM.globalv .nqx ,
302- PARAM.globalv .dq );
303-
304- if (PARAM.inp .init_wfc == " atomic+random"
305- && starting_nw == GlobalC::ucell.natomwfc ) // added by qianrui 2021-5-16
306- {
307- p_wf->atomicrandom (wfcatom, 0 , starting_nw, ik, wfc_basis);
308- }
309-
310- // ====================================================
311- // If not enough atomic wfc are available, complete
312- // with random wfcs
313- // ====================================================
314- p_wf->random (wfcatom.c , GlobalC::ucell.natomwfc , nbands, ik, wfc_basis);
315-
316- // (7) Diago with cg method.
317- std::vector<std::complex <float >> s_wfcatom (starting_nw * nbasis);
318- castmem_z2c_h2h_op ()(cpu_ctx, cpu_ctx, s_wfcatom.data (), wfcatom.c , starting_nw * nbasis);
319-
320- // if(GlobalV::DIAGO_TYPE == "cg") xiaohui modify 2013-09-02
321- if (PARAM.inp .ks_solver == " cg" ) // xiaohui add 2013-09-02
322- {
323- if (phm_in != nullptr )
324- {
325- hsolver::DiagoIterAssist<std::complex <float >>::diagH_subspace_init (phm_in,
326- s_wfcatom.data (),
327- wfcatom.nr ,
328- wfcatom.nc ,
329- wvf,
330- etatom.data ());
331- return ;
332- }
333- else
334- {
335- ModuleBase::WARNING_QUIT (" wavefunc" , " Psi does not exist!" );
336- // this diagonalization method is obsoleted now
337- // GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data());
338- }
339- }
340-
341- assert (nbands <= wfcatom.nr );
342- for (int ib = 0 ; ib < nbands; ib++)
343- {
344- for (int ig = 0 ; ig < nbasis; ig++)
345- {
346- wvf (ib, ig) = s_wfcatom[ib * nbasis + ig];
347- }
348- }
349- }
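199+ // TODO: implement the single-precision (float) CPU path; this specialization is currently an empty stub that returns without modifying wvf.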
350200}
351201
352- void diago_PAO_in_pw_k2 (const int & ik,
353- psi::Psi<std::complex <double >>& wvf,
202+ template <>
203+ void diago_PAO_in_pw_k2 (const base_device::DEVICE_CPU* ctx,
204+ const int & ik,
205+ psi::Psi<std::complex <double >, base_device::DEVICE_CPU>& wvf,
354206 ModulePW::PW_Basis_K* wfc_basis,
355207 wavefunc* p_wf,
356208 const ModuleBase::realArray& tab_at,
357209 const int & lmaxkb,
358- hamilt::Hamilt<std::complex <double >>* phm_in)
210+ hamilt::Hamilt<std::complex <double >, base_device::DEVICE_CPU >* phm_in)
359211{
360212 ModuleBase::TITLE (" wavefunc" , " diago_PAO_in_pw_k2" );
361213
@@ -490,33 +342,8 @@ void diago_PAO_in_pw_k2(const int& ik,
490342 }
491343}
492344
493- template <>
494- void diago_PAO_in_pw_k2 (const base_device::DEVICE_CPU* ctx,
495- const int & ik,
496- psi::Psi<std::complex <float >, base_device::DEVICE_CPU>& wvf,
497- ModulePW::PW_Basis_K* wfc_basis,
498- wavefunc* p_wf,
499- const ModuleBase::realArray& tab_at,
500- const int & lmaxkb,
501- hamilt::Hamilt<std::complex <float >, base_device::DEVICE_CPU>* phm_in)
502- {
503- diago_PAO_in_pw_k2 (ik, wvf, wfc_basis, p_wf, tab_at, lmaxkb, phm_in);
504- }
505-
506- template <>
507- void diago_PAO_in_pw_k2 (const base_device::DEVICE_CPU* ctx,
508- const int & ik,
509- psi::Psi<std::complex <double >, base_device::DEVICE_CPU>& wvf,
510- ModulePW::PW_Basis_K* wfc_basis,
511- wavefunc* p_wf,
512- const ModuleBase::realArray& tab_at,
513- const int & lmaxkb,
514- hamilt::Hamilt<std::complex <double >, base_device::DEVICE_CPU>* phm_in)
515- {
516- diago_PAO_in_pw_k2 (ik, wvf, wfc_basis, p_wf, tab_at, lmaxkb, phm_in);
517- }
518-
519345#if ((defined __CUDA) || (defined __ROCM))
346+
520347template <>
521348void diago_PAO_in_pw_k2 (const base_device::DEVICE_GPU* ctx,
522349 const int & ik,
@@ -625,6 +452,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
625452 delmem_cd_op ()(gpu_ctx, c_wfcatom);
626453 }
627454}
455+
628456template <>
629457void diago_PAO_in_pw_k2 (const base_device::DEVICE_GPU* ctx,
630458 const int & ik,
@@ -733,6 +561,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
733561 delmem_zd_op ()(gpu_ctx, z_wfcatom);
734562 }
735563}
564+
736565#endif
737566
738567} // namespace hamilt
0 commit comments