Skip to content

Commit 921c2e0

Browse files
committed
refator wavefunc
1 parent 46806ed commit 921c2e0

File tree

1 file changed

+14
-185
lines changed

1 file changed

+14
-185
lines changed

source/module_psi/wavefunc.cpp

Lines changed: 14 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -186,176 +186,28 @@ int wavefunc::get_starting_nw() const
186186
namespace hamilt
187187
{
188188

189-
void diago_PAO_in_pw_k2(const int& ik,
190-
psi::Psi<std::complex<float>>& wvf,
189+
template <>
190+
void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
191+
const int& ik,
192+
psi::Psi<std::complex<float>, base_device::DEVICE_CPU>& wvf,
191193
ModulePW::PW_Basis_K* wfc_basis,
192194
wavefunc* p_wf,
193195
const ModuleBase::realArray& tab_at,
194196
const int& lmaxkb,
195-
hamilt::Hamilt<std::complex<float>>* phm_in)
197+
hamilt::Hamilt<std::complex<float>, base_device::DEVICE_CPU>* phm_in)
196198
{
197-
ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2");
198-
199-
const int nbasis = wvf.get_nbasis();
200-
const int nbands = wvf.get_nbands();
201-
const int current_nbasis = wfc_basis->npwk[ik];
202-
203-
if (PARAM.inp.init_wfc == "file")
204-
{
205-
ModuleBase::ComplexMatrix wfcatom(nbands, nbasis);
206-
std::stringstream filename;
207-
int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot);
208-
filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat";
209-
ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom);
210-
211-
std::vector<std::complex<float>> s_wfcatom(nbands * nbasis);
212-
castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, nbands * nbasis);
213-
214-
if (PARAM.inp.ks_solver == "cg")
215-
{
216-
std::vector<float> etfile(nbands, 0.0);
217-
if (phm_in != nullptr)
218-
{
219-
hsolver::DiagoIterAssist<std::complex<float>>::diagH_subspace_init(phm_in,
220-
s_wfcatom.data(),
221-
wfcatom.nr,
222-
wfcatom.nc,
223-
wvf,
224-
etfile.data());
225-
return;
226-
}
227-
else
228-
{
229-
ModuleBase::WARNING_QUIT("wavefunc", "Psi does not exist!");
230-
}
231-
}
232-
233-
assert(nbands <= wfcatom.nr);
234-
for (int ib = 0; ib < nbands; ib++)
235-
{
236-
for (int ig = 0; ig < nbasis; ig++)
237-
{
238-
wvf(ib, ig) = s_wfcatom[ib * nbasis + ig];
239-
}
240-
}
241-
return;
242-
}
243-
244-
const int starting_nw = p_wf->get_starting_nw();
245-
if (starting_nw == 0)
246-
{
247-
return;
248-
}
249-
assert(starting_nw > 0);
250-
std::vector<float> etatom(starting_nw, 0.0);
251-
252-
// special case here! use Psi(k-1) for the initialization of Psi(k)
253-
// this method should be tested.
254-
/*if(PARAM.inp.calculation == "nscf" && GlobalC::ucell.natomwfc == 0 && ik>0)
255-
{
256-
//this is memsaver case
257-
if(wvf.get_nk() == 1)
258-
{
259-
return;
260-
}
261-
else
262-
{
263-
ModuleBase::GlobalFunc::COPYARRAY(&wvf(ik-1, 0, 0), &wvf(ik, 0, 0), wvf.get_nbasis()* wvf.get_nbands());
264-
return;
265-
}
266-
}
267-
*/
268-
269-
if (PARAM.inp.init_wfc == "random" || (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
270-
{
271-
p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis);
272-
273-
if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02
274-
{
275-
if (phm_in != nullptr)
276-
{
277-
hsolver::DiagoIterAssist<std::complex<float>>::diagH_subspace(phm_in, wvf, wvf, etatom.data());
278-
return;
279-
}
280-
else
281-
{
282-
ModuleBase::WARNING_QUIT("wavefunc", "Hamiltonian does not exist!");
283-
}
284-
}
285-
}
286-
else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
287-
{
288-
ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc
289-
if (PARAM.inp.test_wf)
290-
{
291-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw);
292-
}
293-
294-
p_wf->atomic_wfc(ik,
295-
current_nbasis,
296-
GlobalC::ucell.lmax_ppwf,
297-
lmaxkb,
298-
wfc_basis,
299-
wfcatom,
300-
tab_at,
301-
PARAM.globalv.nqx,
302-
PARAM.globalv.dq);
303-
304-
if (PARAM.inp.init_wfc == "atomic+random"
305-
&& starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
306-
{
307-
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
308-
}
309-
310-
//====================================================
311-
// If not enough atomic wfc are available, complete
312-
// with random wfcs
313-
//====================================================
314-
p_wf->random(wfcatom.c, GlobalC::ucell.natomwfc, nbands, ik, wfc_basis);
315-
316-
// (7) Diago with cg method.
317-
std::vector<std::complex<float>> s_wfcatom(starting_nw * nbasis);
318-
castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, starting_nw * nbasis);
319-
320-
// if(GlobalV::DIAGO_TYPE == "cg") xiaohui modify 2013-09-02
321-
if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02
322-
{
323-
if (phm_in != nullptr)
324-
{
325-
hsolver::DiagoIterAssist<std::complex<float>>::diagH_subspace_init(phm_in,
326-
s_wfcatom.data(),
327-
wfcatom.nr,
328-
wfcatom.nc,
329-
wvf,
330-
etatom.data());
331-
return;
332-
}
333-
else
334-
{
335-
ModuleBase::WARNING_QUIT("wavefunc", "Psi does not exist!");
336-
// this diagonalization method is obsoleted now
337-
// GlobalC::hm.diagH_subspace(ik ,starting_nw, nbands, wfcatom, wfcatom, etatom.data());
338-
}
339-
}
340-
341-
assert(nbands <= wfcatom.nr);
342-
for (int ib = 0; ib < nbands; ib++)
343-
{
344-
for (int ig = 0; ig < nbasis; ig++)
345-
{
346-
wvf(ib, ig) = s_wfcatom[ib * nbasis + ig];
347-
}
348-
}
349-
}
199+
// TODO float func
350200
}
351201

352-
void diago_PAO_in_pw_k2(const int& ik,
353-
psi::Psi<std::complex<double>>& wvf,
202+
template <>
203+
void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
204+
const int& ik,
205+
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>& wvf,
354206
ModulePW::PW_Basis_K* wfc_basis,
355207
wavefunc* p_wf,
356208
const ModuleBase::realArray& tab_at,
357209
const int& lmaxkb,
358-
hamilt::Hamilt<std::complex<double>>* phm_in)
210+
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>* phm_in)
359211
{
360212
ModuleBase::TITLE("wavefunc", "diago_PAO_in_pw_k2");
361213

@@ -490,33 +342,8 @@ void diago_PAO_in_pw_k2(const int& ik,
490342
}
491343
}
492344

493-
template <>
494-
void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
495-
const int& ik,
496-
psi::Psi<std::complex<float>, base_device::DEVICE_CPU>& wvf,
497-
ModulePW::PW_Basis_K* wfc_basis,
498-
wavefunc* p_wf,
499-
const ModuleBase::realArray& tab_at,
500-
const int& lmaxkb,
501-
hamilt::Hamilt<std::complex<float>, base_device::DEVICE_CPU>* phm_in)
502-
{
503-
diago_PAO_in_pw_k2(ik, wvf, wfc_basis, p_wf, tab_at, lmaxkb, phm_in);
504-
}
505-
506-
template <>
507-
void diago_PAO_in_pw_k2(const base_device::DEVICE_CPU* ctx,
508-
const int& ik,
509-
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>& wvf,
510-
ModulePW::PW_Basis_K* wfc_basis,
511-
wavefunc* p_wf,
512-
const ModuleBase::realArray& tab_at,
513-
const int& lmaxkb,
514-
hamilt::Hamilt<std::complex<double>, base_device::DEVICE_CPU>* phm_in)
515-
{
516-
diago_PAO_in_pw_k2(ik, wvf, wfc_basis, p_wf, tab_at, lmaxkb, phm_in);
517-
}
518-
519345
#if ((defined __CUDA) || (defined __ROCM))
346+
520347
template <>
521348
void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
522349
const int& ik,
@@ -625,6 +452,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
625452
delmem_cd_op()(gpu_ctx, c_wfcatom);
626453
}
627454
}
455+
628456
template <>
629457
void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
630458
const int& ik,
@@ -733,6 +561,7 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
733561
delmem_zd_op()(gpu_ctx, z_wfcatom);
734562
}
735563
}
564+
736565
#endif
737566

738567
} // namespace hamilt

0 commit comments

Comments
 (0)