Skip to content

Commit f54fc18

Browse files
committed
set psi settings
1 parent 33b18fb commit f54fc18

File tree

2 files changed

+79
-53
lines changed

2 files changed

+79
-53
lines changed

source/module_psi/psi_init.cpp

Lines changed: 77 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
#include "module_base/tool_quit.h"
66
#include "module_hsolver/diago_iter_assist.h"
77
#include "module_parameter/parameter.h"
8-
#include "module_psi/psi_initializer_random.h"
98
#include "module_psi/psi_initializer_atomic.h"
109
#include "module_psi/psi_initializer_atomic_random.h"
1110
#include "module_psi/psi_initializer_nao.h"
1211
#include "module_psi/psi_initializer_nao_random.h"
12+
#include "module_psi/psi_initializer_random.h"
1313
namespace psi
1414
{
1515

@@ -25,17 +25,33 @@ PSIInit<T, Device>::PSIInit(const std::string& init_wfc_in,
2525
this->basis_type = basis_type_in;
2626
this->use_psiinitializer = use_psiinitializer_in;
2727
this->pw_wfc = pw_wfc_in;
28+
29+
if (PARAM.inp.psi_initializer == true)
30+
{
31+
this->init_psi_method = "new";
32+
}
33+
else
34+
{
35+
if (PARAM.inp.init_wfc == "file" || PARAM.inp.device == "gpu" || PARAM.inp.esolver_type == "sdft")
36+
{
37+
this->init_psi_method = "old"; // old method;
38+
}
39+
else
40+
{
41+
this->init_psi_method = "new"; // new method;
42+
}
43+
}
2844
}
2945

3046
template <typename T, typename Device>
3147
PSIInit<T, Device>::~PSIInit()
3248
{
33-
if (this->use_psiinitializer)
49+
if (this->init_psi_method == "new")
3450
{
3551
{
3652
this->psi_init->deallocate_psig();
3753
// delete this->psi_init;
38-
// this->psi_init = nullptr;
54+
// this->psi_init = nullptr;
3955
}
4056
}
4157
}
@@ -50,53 +66,56 @@ void PSIInit<T, Device>::prepare_init(Structure_Factor* p_sf,
5066
#endif
5167
pseudopot_cell_vnl* p_ppcell)
5268
{
53-
if (!this->use_psiinitializer)
69+
if (this->init_psi_method == "old")
5470
{
5571
return;
5672
}
57-
// under restriction of C++11, std::unique_ptr can not be allocate via std::make_unique
58-
// use new instead, but will cause asymmetric allocation and deallocation, in literal aspect
59-
ModuleBase::timer::tick("PSIInit", "prepare_init");
60-
if ((this->init_wfc.substr(0, 6) == "atomic") && (p_ucell->natomwfc == 0))
61-
{
62-
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerRandom<T, Device>());
63-
}
64-
else if (this->init_wfc == "atomic")
65-
{
66-
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerAtomic<T, Device>());
67-
}
68-
else if (this->init_wfc == "random")
69-
{
70-
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerRandom<T, Device>());
71-
}
72-
else if (this->init_wfc == "nao")
73-
{
74-
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerNAO<T, Device>());
75-
}
76-
else if (this->init_wfc == "atomic+random")
77-
{
78-
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerAtomicRandom<T, Device>());
79-
}
80-
else if (this->init_wfc == "nao+random")
81-
{
82-
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerNAORandom<T, Device>());
83-
}
8473
else
8574
{
86-
ModuleBase::WARNING_QUIT("PSIInit::prepare_init", "for new psi initializer, init_wfc type not supported");
87-
}
75+
// under restriction of C++11, std::unique_ptr can not be allocate via std::make_unique
76+
// use new instead, but will cause asymmetric allocation and deallocation, in literal aspect
77+
ModuleBase::timer::tick("PSIInit", "prepare_init");
78+
if ((this->init_wfc.substr(0, 6) == "atomic") && (p_ucell->natomwfc == 0))
79+
{
80+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerRandom<T, Device>());
81+
}
82+
else if (this->init_wfc == "atomic")
83+
{
84+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerAtomic<T, Device>());
85+
}
86+
else if (this->init_wfc == "random")
87+
{
88+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerRandom<T, Device>());
89+
}
90+
else if (this->init_wfc == "nao")
91+
{
92+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerNAO<T, Device>());
93+
}
94+
else if (this->init_wfc == "atomic+random")
95+
{
96+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerAtomicRandom<T, Device>());
97+
}
98+
else if (this->init_wfc == "nao+random")
99+
{
100+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerNAORandom<T, Device>());
101+
}
102+
else
103+
{
104+
ModuleBase::WARNING_QUIT("PSIInit::prepare_init", "for new psi initializer, init_wfc type not supported");
105+
}
88106

89-
//! function polymorphism is moved from constructor to function initialize.
90-
//! Two slightly different implementation are for MPI and serial case, respectively.
107+
//! function polymorphism is moved from constructor to function initialize.
108+
//! Two slightly different implementation are for MPI and serial case, respectively.
91109
#ifdef __MPI
92-
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, p_parak, random_seed, p_ppcell, rank);
110+
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, p_parak, random_seed, p_ppcell, rank);
93111
#else
94-
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, random_seed, p_ppcell);
112+
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, random_seed, p_ppcell);
95113
#endif
96114

97-
// always new->initialize->tabulate->allocate->proj_ao_onkG
98-
this->psi_init->tabulate();
99-
ModuleBase::timer::tick("PSIInit", "prepare_init");
115+
// always new->initialize->tabulate->allocate->proj_ao_onkG
116+
this->psi_init->tabulate();
117+
ModuleBase::timer::tick("PSIInit", "prepare_init");
118+
}
100119
}
101120

102121
template <typename T, typename Device>
@@ -121,7 +140,7 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
121140
// the basis (representation) with operator (hamiltonian) and solver (diagonalization).
122141
// This feature requires feasible Linear Algebra library in-built in ABACUS, which
123142
// is not ready yet.
124-
if (this->use_psiinitializer) // new method
143+
if (this->init_psi_method == "new") // new method
125144
{
126145
// PsiInitializer drag initialization of pw wavefunction out of HSolver, make psi
127146
// initialization decoupled with HSolver (diagonalization) procedure.
@@ -150,28 +169,31 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
150169
template <typename T, typename Device>
151170
void PSIInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell)
152171
{
153-
if (this->use_psiinitializer)
172+
if (this->init_psi_method == "new")
154173
{
155174
} // do not need to do anything because the interpolate table is unchanged
156175
else // old initialization method, used in EXX calculation
157176
{
158177
this->wf_old.init_after_vc(nks); // reallocate wanf2, the planewave expansion of lcao
159-
this->wf_old.init_at_1(p_sf, &p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
178+
this->wf_old.init_at_1(
179+
p_sf,
180+
&p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
160181
}
161182
}
162183

163184
// in the following function, the psi on Device will be initialized with the CPU psi
164185
template <typename T, typename Device>
165-
void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi, // the one always on CPU
166-
psi::Psi<T, Device>* kspw_psi, // the one may be on GPU. In CPU case, it is the same as psi
167-
hamilt::Hamilt<T, Device>* p_hamilt,
168-
const pseudopot_cell_vnl& nlpp,
169-
std::ofstream& ofs_running,
170-
const bool is_already_initpsi)
186+
void PSIInit<T, Device>::initialize_psi(
187+
Psi<std::complex<double>>* psi, // the one always on CPU
188+
psi::Psi<T, Device>* kspw_psi, // the one may be on GPU. In CPU case, it is the same as psi
189+
hamilt::Hamilt<T, Device>* p_hamilt,
190+
const pseudopot_cell_vnl& nlpp,
191+
std::ofstream& ofs_running,
192+
const bool is_already_initpsi)
171193
{
172194
ModuleBase::timer::tick("PSIInit", "initialize_psi");
173195

174-
if (PARAM.inp.psi_initializer)
196+
if (this->init_psi_method == "new")
175197
{
176198
// if psig is not allocated before, allocate it
177199
if (!this->psi_init->psig_use_count())
@@ -183,7 +205,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi, //
183205
// like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
184206
for (int ik = 0; ik < this->pw_wfc->nks; ik++)
185207
{
186-
//! Fix the wavefunction to initialize at given kpoint.
208+
//! Fix the wavefunction to initialize at given kpoint.
187209
// This will fix the kpoint for CPU case. For GPU, we should additionally call fix_k for kspw_psi
188210
psi->fix_k(ik);
189211

@@ -256,7 +278,8 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi, //
256278
}
257279

258280
// for the Davidson method, we just copy the wavefunction (partially)
259-
// For GPU: although this is simply the copy operation, if GPU present, this should be a data sending operation
281+
// For GPU: although this is simply the copy operation, if GPU present, this should be a data sending
282+
// operation
260283
for (int iband = 0; iband < kspw_psi->get_nbands(); iband++)
261284
{
262285
for (int ibasis = 0; ibasis < kspw_psi->get_nbasis(); ibasis++)
@@ -266,7 +289,8 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi, //
266289
}
267290
} // end k-point loop
268291

269-
if (this->basis_type != "lcao_in_pw") // if not LCAO_IN_PW case, we can release the memory of psig after initailization is done.
292+
if (this->basis_type
293+
!= "lcao_in_pw") // if not LCAO_IN_PW case, we can release the memory of psig after initailization is done.
270294
{
271295
this->psi_init->deallocate_psig();
272296
}

source/module_psi/psi_init.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ class PSIInit
9494
ModulePW::PW_Basis_K* pw_wfc = nullptr;
9595

9696
Device* ctx = {};
97+
98+
std::string init_psi_method = "old";
9799
};
98100

99101
} // namespace psi

0 commit comments

Comments
 (0)