Skip to content

Commit 46806ed

Browse files
committed
refactor diago_PAO_in_pw_k2 func
1 parent 6df9240 commit 46806ed

File tree

1 file changed

+34
-43
lines changed

1 file changed

+34
-43
lines changed

source/module_psi/wavefunc.cpp

Lines changed: 34 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,13 @@ psi::Psi<std::complex<double>>* wavefunc::allocate(const int nkstot, const int n
6161
wanf2[0].create(PARAM.globalv.nlocal, npwx * PARAM.globalv.npol);
6262

6363
// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
64-
const size_t memory_cost = sizeof(std::complex<double>) * PARAM.globalv.nlocal * (PARAM.globalv.npol * npwx);
64+
const size_t memory_cost
65+
= sizeof(std::complex<double>) * PARAM.globalv.nlocal * (PARAM.globalv.npol * npwx);
6566

6667
std::cout << " Memory for wanf2 (MB): " << static_cast<double>(memory_cost) / 1024.0 / 1024.0 << std::endl;
6768
ModuleBase::Memory::record("WF::wanf2", memory_cost);
6869
}
69-
70+
7071
// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
7172
const size_t memory_cost = sizeof(std::complex<double>) * PARAM.inp.nbands * (PARAM.globalv.npol * npwx);
7273

@@ -89,7 +90,8 @@ psi::Psi<std::complex<double>>* wavefunc::allocate(const int nkstot, const int n
8990
}
9091

9192
// WARNING: put the sizeof() be the first to avoid the overflow of the multiplication of int
92-
const size_t memory_cost = sizeof(std::complex<double>) * nks2 * PARAM.globalv.nlocal * (npwx * PARAM.globalv.npol);
93+
const size_t memory_cost
94+
= sizeof(std::complex<double>) * nks2 * PARAM.globalv.nlocal * (npwx * PARAM.globalv.npol);
9395

9496
std::cout << " Memory for wanf2 (MB): " << static_cast<double>(memory_cost) / 1024.0 / 1024.0 << std::endl;
9597
ModuleBase::Memory::record("WF::wanf2", memory_cost);
@@ -206,7 +208,6 @@ void diago_PAO_in_pw_k2(const int& ik,
206208
filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat";
207209
ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom);
208210

209-
210211
std::vector<std::complex<float>> s_wfcatom(nbands * nbasis);
211212
castmem_z2c_h2h_op()(cpu_ctx, cpu_ctx, s_wfcatom.data(), wfcatom.c, nbands * nbasis);
212213

@@ -285,9 +286,10 @@ void diago_PAO_in_pw_k2(const int& ik,
285286
else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
286287
{
287288
ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc
288-
if (PARAM.inp.test_wf) {
289+
if (PARAM.inp.test_wf)
290+
{
289291
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw);
290-
}
292+
}
291293

292294
p_wf->atomic_wfc(ik,
293295
current_nbasis,
@@ -299,7 +301,8 @@ void diago_PAO_in_pw_k2(const int& ik,
299301
PARAM.globalv.nqx,
300302
PARAM.globalv.dq);
301303

302-
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
304+
if (PARAM.inp.init_wfc == "atomic+random"
305+
&& starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
303306
{
304307
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
305308
}
@@ -365,10 +368,9 @@ void diago_PAO_in_pw_k2(const int& ik,
365368
ModuleBase::ComplexMatrix wfcatom(nbands, nbasis);
366369
std::stringstream filename;
367370
int ik_tot = K_Vectors::get_ik_global(ik, p_wf->nkstot);
368-
filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat";
371+
filename << PARAM.globalv.global_readin_dir << "WAVEFUNC" << ik_tot + 1 << ".dat";
369372
ModuleIO::read_wfc_pw(filename.str(), wfc_basis, ik, p_wf->nkstot, wfcatom);
370373

371-
372374
if (PARAM.inp.ks_solver == "cg")
373375
{
374376
std::vector<double> etfile(nbands, 0.0);
@@ -398,41 +400,18 @@ void diago_PAO_in_pw_k2(const int& ik,
398400
}
399401
return;
400402
}
401-
402-
// special case here! use Psi(k-1) for the initialization of Psi(k)
403-
// this method should be tested.
404-
/*if(PARAM.inp.calculation == "nscf" && GlobalC::ucell.natomwfc == 0 && ik>0)
405-
{
406-
//this is memsaver case
407-
if(wvf.get_nk() == 1)
408-
{
409-
return;
410-
}
411-
else
412-
{
413-
ModuleBase::GlobalFunc::COPYARRAY(&wvf(ik-1, 0, 0), &wvf(ik, 0, 0), wvf.get_nbasis()* wvf.get_nbands());
414-
return;
415-
}
416-
}
417-
*/
418-
419-
const int starting_nw = p_wf->get_starting_nw();
420-
if (starting_nw == 0)
421-
{
422-
return;
423-
}
424-
425-
assert(starting_nw > 0);
426-
std::vector<double> etatom(starting_nw, 0.0);
427-
428-
if (PARAM.inp.init_wfc == "random" || (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
403+
else if (PARAM.inp.init_wfc == "random"
404+
|| (PARAM.inp.init_wfc.substr(0, 6) == "atomic" && GlobalC::ucell.natomwfc == 0))
429405
{
430406
p_wf->random(wvf.get_pointer(), 0, nbands, ik, wfc_basis);
431-
if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02
407+
408+
if (PARAM.inp.ks_solver == "cg")
432409
{
410+
std::vector<double> etrandom(nbands, 0.0);
411+
433412
if (phm_in != nullptr)
434413
{
435-
hsolver::DiagoIterAssist<std::complex<double>>::diagH_subspace(phm_in, wvf, wvf, etatom.data());
414+
hsolver::DiagoIterAssist<std::complex<double>>::diagH_subspace(phm_in, wvf, wvf, etrandom.data());
436415
return;
437416
}
438417
else
@@ -443,7 +422,15 @@ void diago_PAO_in_pw_k2(const int& ik,
443422
}
444423
else if (PARAM.inp.init_wfc.substr(0, 6) == "atomic")
445424
{
446-
ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis); // added by zhengdy-soc
425+
const int starting_nw = p_wf->get_starting_nw();
426+
if (starting_nw == 0)
427+
{
428+
return;
429+
}
430+
assert(starting_nw > 0);
431+
432+
ModuleBase::ComplexMatrix wfcatom(starting_nw, nbasis);
433+
447434
if (PARAM.inp.test_wf)
448435
{
449436
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "starting_nw", starting_nw);
@@ -459,7 +446,8 @@ void diago_PAO_in_pw_k2(const int& ik,
459446
PARAM.globalv.nqx,
460447
PARAM.globalv.dq);
461448

462-
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
449+
if (PARAM.inp.init_wfc == "atomic+random"
450+
&& starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
463451
{
464452
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
465453
}
@@ -474,6 +462,7 @@ void diago_PAO_in_pw_k2(const int& ik,
474462
// if(GlobalV::DIAGO_TYPE == "cg") xiaohui modify 2013-09-02
475463
if (PARAM.inp.ks_solver == "cg") // xiaohui add 2013-09-02
476464
{
465+
std::vector<double> etatom(starting_nw, 0.0);
477466
if (phm_in != nullptr)
478467
{
479468
hsolver::DiagoIterAssist<std::complex<double>>::diagH_subspace_init(phm_in,
@@ -573,7 +562,8 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
573562
tab_at,
574563
PARAM.globalv.nqx,
575564
PARAM.globalv.dq);
576-
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
565+
if (PARAM.inp.init_wfc == "atomic+random"
566+
&& starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
577567
{
578568
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
579569
}
@@ -679,7 +669,8 @@ void diago_PAO_in_pw_k2(const base_device::DEVICE_GPU* ctx,
679669
tab_at,
680670
PARAM.globalv.nqx,
681671
PARAM.globalv.dq);
682-
if (PARAM.inp.init_wfc == "atomic+random" && starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
672+
if (PARAM.inp.init_wfc == "atomic+random"
673+
&& starting_nw == GlobalC::ucell.natomwfc) // added by qianrui 2021-5-16
683674
{
684675
p_wf->atomicrandom(wfcatom, 0, starting_nw, ik, wfc_basis);
685676
}

0 commit comments

Comments
 (0)