Skip to content

Commit 2352026

Browse files
committed
test
1 parent ee4ad57 commit 2352026

File tree

2 files changed

+62
-41
lines changed

2 files changed

+62
-41
lines changed

source/module_psi/psi_init.cpp

Lines changed: 59 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ PSIInit<T, Device>::PSIInit(const std::string& init_wfc_in,
2525
this->basis_type = basis_type_in;
2626
this->use_psiinitializer = use_psiinitializer_in;
2727
this->pw_wfc = pw_wfc_in;
28+
29+
if (PARAM.inp.init_wfc == "file" || PARAM.inp.device == "gpu")
30+
{
31+
this->init_psi_method = "old"; // old method;
32+
}
33+
else
34+
{
35+
this->init_psi_method = "new"; // new method;
36+
}
2837
}
2938

3039
template <typename T, typename Device>
@@ -37,53 +46,60 @@ void PSIInit<T, Device>::prepare_init(Structure_Factor* p_sf,
3746
#endif
3847
pseudopot_cell_vnl* p_ppcell)
3948
{
40-
if (!this->use_psiinitializer)
49+
if (this->init_psi_method == "old")
4150
{
4251
return;
4352
}
44-
// under restriction of C++11, std::unique_ptr can not be allocate via std::make_unique
45-
// use new instead, but will cause asymmetric allocation and deallocation, in literal aspect
46-
ModuleBase::timer::tick("PSIInit", "prepare_init");
47-
if ((this->init_wfc.substr(0, 6) == "atomic") && (p_ucell->natomwfc == 0))
48-
{
49-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_random<T, Device>());
50-
}
51-
else if (this->init_wfc == "atomic")
52-
{
53-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_atomic<T, Device>());
54-
}
55-
else if (this->init_wfc == "random")
56-
{
57-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_random<T, Device>());
58-
}
59-
else if (this->init_wfc == "nao")
60-
{
61-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_nao<T, Device>());
62-
}
63-
else if (this->init_wfc == "atomic+random")
64-
{
65-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_atomic_random<T, Device>());
66-
}
67-
else if (this->init_wfc == "nao+random")
68-
{
69-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_nao_random<T, Device>());
70-
}
7153
else
7254
{
73-
ModuleBase::WARNING_QUIT("PSIInit::prepare_init", "for new psi initializer, init_wfc type not supported");
74-
}
55+
ModuleBase::timer::tick("PSIInit", "prepare_init");
7556

76-
//! function polymorphism is moved from constructor to function initialize.
77-
//! Two slightly different implementation are for MPI and serial case, respectively.
57+
// under restriction of C++11, std::unique_ptr can not be allocate via std::make_unique
58+
// use new instead, but will cause asymmetric allocation and deallocation, in literal aspect
59+
60+
if ((this->init_wfc.substr(0, 6) == "atomic") && (p_ucell->natomwfc == 0))
61+
{
62+
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_random<T, Device>());
63+
}
64+
else if (this->init_wfc == "atomic")
65+
{
66+
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_atomic<T, Device>());
67+
}
68+
else if (this->init_wfc == "random")
69+
{
70+
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_random<T, Device>());
71+
}
72+
else if (this->init_wfc == "nao")
73+
{
74+
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_nao<T, Device>());
75+
}
76+
else if (this->init_wfc == "atomic+random")
77+
{
78+
this->psi_init
79+
= std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_atomic_random<T, Device>());
80+
}
81+
else if (this->init_wfc == "nao+random")
82+
{
83+
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_nao_random<T, Device>());
84+
}
85+
else
86+
{
87+
ModuleBase::WARNING_QUIT("PSIInit::prepare_init", "for new psi initializer, init_wfc type not supported");
88+
}
89+
90+
//! function polymorphism is moved from constructor to function initialize.
91+
//! Two slightly different implementation are for MPI and serial case, respectively.
7892
#ifdef __MPI
79-
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, p_parak, random_seed, p_ppcell, rank);
93+
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, p_parak, random_seed, p_ppcell, rank);
8094
#else
81-
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, random_seed, p_ppcell);
95+
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, random_seed, p_ppcell);
8296
#endif
8397

84-
// always new->initialize->tabulate->allocate->proj_ao_onkG
85-
this->psi_init->tabulate();
86-
ModuleBase::timer::tick("PSIInit", "prepare_init");
98+
// always new->initialize->tabulate->allocate->proj_ao_onkG
99+
this->psi_init->tabulate();
100+
101+
ModuleBase::timer::tick("PSIInit", "prepare_init");
102+
}
87103
}
88104

89105
template <typename T, typename Device>
@@ -108,7 +124,7 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
108124
// the basis (representation) with operator (hamiltonian) and solver (diagonalization).
109125
// This feature requires feasible Linear Algebra library in-built in ABACUS, which
110126
// is not ready yet.
111-
if (this->use_psiinitializer) // new method
127+
if (this->init_psi_method == "new") // new method
112128
{
113129
// psi_initializer drag initialization of pw wavefunction out of HSolver, make psi
114130
// initialization decoupled with HSolver (diagonalization) procedure.
@@ -137,13 +153,15 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
137153
template <typename T, typename Device>
138154
void PSIInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell)
139155
{
140-
if (this->use_psiinitializer)
156+
if (this->init_psi_method == "new")
141157
{
142158
} // do not need to do anything because the interpolate table is unchanged
143159
else // old initialization method, used in EXX calculation
144160
{
145161
this->wf_old.init_after_vc(nks); // reallocate wanf2, the planewave expansion of lcao
146-
this->wf_old.init_at_1(p_sf, &p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
162+
this->wf_old.init_at_1(
163+
p_sf,
164+
&p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
147165
}
148166
}
149167

@@ -157,7 +175,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
157175
{
158176
ModuleBase::timer::tick("PSIInit", "initialize_psi");
159177

160-
if (PARAM.inp.psi_initializer)
178+
if (this->init_psi_method == "new")
161179
{
162180
// if psig is not allocated before, allocate it
163181
if (!this->psi_init->psig_use_count())

source/module_psi/psi_init.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ class PSIInit
9494
ModulePW::PW_Basis_K* pw_wfc = nullptr;
9595

9696
Device* ctx = {};
97+
98+
std::string init_psi_method = "old";
99+
// old or new
97100
};
98101

99102
} // namespace psi

0 commit comments

Comments
 (0)