Skip to content

Commit 3d05448

Browse files
committed
initial commit
1 parent ee4ad57 commit 3d05448

20 files changed

+293
-249
lines changed

source/module_esolver/esolver_ks_pw.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>
5252
//! hide the psi in ESolver_KS for tmp use
5353
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* psi = nullptr;
5454

55-
// psi_initializer controller
55+
// PsiInitializer controller
5656
psi::PSIInit<T, Device>* p_wf_init = nullptr;
5757

5858
Device* ctx = {};

source/module_io/read_input_item_postprocess.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ void ReadInput::item_postprocess()
196196
197197
In the future lcao_in_pw will have its own ESolver.
198198
199-
2023/12/22 use new psi_initializer to expand numerical
199+
2023/12/22 use new PsiInitializer to expand numerical
200200
atomic orbitals, ykhuang
201201
*/
202202
if (para.input.towannier90 && para.input.basis_type == "lcao_in_pw")

source/module_io/read_input_item_system.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ void ReadInput::item_system()
502502
}
503503
{
504504
Input_Item item("psi_initializer");
505-
item.annotation = "whether to use psi_initializer";
505+
item.annotation = "whether to use PsiInitializer";
506506
item.reset_value = [](const Input_Item& item, Parameter& para) {
507507
if (para.input.basis_type == "lcao_in_pw")
508508
{

source/module_io/to_wannier90_lcao_in_pw.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ void toWannier90_LCAO_IN_PW::calculate(
4343

4444
Structure_Factor* sf_ptr = const_cast<Structure_Factor*>(&sf);
4545
ModulePW::PW_Basis_K* wfcpw_ptr = const_cast<ModulePW::PW_Basis_K*>(wfcpw);
46-
this->psi_init_ = new psi_initializer_nao<std::complex<double>, base_device::DEVICE_CPU>();
46+
this->psi_init_ = new PsiInitializerNAO<std::complex<double>, base_device::DEVICE_CPU>();
4747
#ifdef __MPI
4848
this->psi_init_->initialize(sf_ptr, wfcpw_ptr, &(GlobalC::ucell), &(GlobalC::Pkpoints), 1, nullptr, GlobalV::MY_RANK);
4949
#else

source/module_io/to_wannier90_lcao_in_pw.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class toWannier90_LCAO_IN_PW : public toWannier90_PW
6868
protected:
6969
const Parallel_Orbitals* ParaV;
7070
/// @brief psi initializer for expanding nao in planewave basis
71-
psi_initializer<std::complex<double>, base_device::DEVICE_CPU>* psi_init_;
71+
PsiInitializer<std::complex<double>, base_device::DEVICE_CPU>* psi_init_;
7272

7373
/// @brief get Bloch function from LCAO wavefunction
7474
/// @param psi_in

source/module_psi/psi_init.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
#include "module_base/tool_quit.h"
66
#include "module_hsolver/diago_iter_assist.h"
77
#include "module_parameter/parameter.h"
8+
#include "module_psi/psi_initializer_random.h"
89
#include "module_psi/psi_initializer_atomic.h"
910
#include "module_psi/psi_initializer_atomic_random.h"
1011
#include "module_psi/psi_initializer_nao.h"
1112
#include "module_psi/psi_initializer_nao_random.h"
12-
#include "module_psi/psi_initializer_random.h"
1313
namespace psi
1414
{
1515

@@ -46,27 +46,27 @@ void PSIInit<T, Device>::prepare_init(Structure_Factor* p_sf,
4646
ModuleBase::timer::tick("PSIInit", "prepare_init");
4747
if ((this->init_wfc.substr(0, 6) == "atomic") && (p_ucell->natomwfc == 0))
4848
{
49-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_random<T, Device>());
49+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerRandom<T, Device>());
5050
}
5151
else if (this->init_wfc == "atomic")
5252
{
53-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_atomic<T, Device>());
53+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerAtomic<T, Device>());
5454
}
5555
else if (this->init_wfc == "random")
5656
{
57-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_random<T, Device>());
57+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerRandom<T, Device>());
5858
}
5959
else if (this->init_wfc == "nao")
6060
{
61-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_nao<T, Device>());
61+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerNAO<T, Device>());
6262
}
6363
else if (this->init_wfc == "atomic+random")
6464
{
65-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_atomic_random<T, Device>());
65+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerAtomicRandom<T, Device>());
6666
}
6767
else if (this->init_wfc == "nao+random")
6868
{
69-
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_nao_random<T, Device>());
69+
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerNAORandom<T, Device>());
7070
}
7171
else
7272
{
@@ -110,7 +110,7 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
110110
// is not ready yet.
111111
if (this->use_psiinitializer) // new method
112112
{
113-
// psi_initializer drag initialization of pw wavefunction out of HSolver, make psi
113+
// PsiInitializer drags initialization of pw wavefunction out of HSolver, making psi
114114
// initialization decoupled with HSolver (diagonalization) procedure.
115115
// However, because EXX is hard to maintain, we still use the old method for EXX.
116116
// LCAOINPW in version >= 3.5.0 uses this new method.
@@ -147,9 +147,10 @@ void PSIInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf, pseud
147147
}
148148
}
149149

150+
// in the following function, the psi on Device will be initialized with the CPU psi
150151
template <typename T, typename Device>
151-
void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
152-
psi::Psi<T, Device>* kspw_psi,
152+
void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi, // the one always on CPU
153+
psi::Psi<T, Device>* kspw_psi, // the one may be on GPU. In CPU case, it is the same as psi
153154
hamilt::Hamilt<T, Device>* p_hamilt,
154155
const pseudopot_cell_vnl& nlpp,
155156
std::ofstream& ofs_running,
@@ -169,7 +170,8 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
169170
// like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
170171
for (int ik = 0; ik < this->pw_wfc->nks; ik++)
171172
{
172-
//! Fix the wavefunction to initialize at given kpoint
173+
//! Fix the wavefunction to initialize at given kpoint.
174+
// This will fix the kpoint for CPU case. For GPU, we should additionally call fix_k for kspw_psi
173175
psi->fix_k(ik);
174176

175177
//! Update Hamiltonian from other kpoint to the given one
@@ -179,20 +181,20 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
179181
//! and G is wavevector of the periodic part of the Bloch function
180182
this->psi_init->proj_ao_onkG(ik);
181183

182-
//! psi_initializer manages memory of psig with shared pointer,
184+
//! PsiInitializer manages memory of psig (allocated with new),
183185
//! its access to use is shared here via a raw pointer
184-
//! therefore once the psi_initializer is destructed, psig will be destructed, too
186+
//! therefore once the PsiInitializer is destructed, psig will be destructed, too
185187
//! this way, we can avoid memory leak and undefined behavior
186-
std::weak_ptr<psi::Psi<T, Device>> psig = this->psi_init->share_psig();
187-
188-
if (psig.expired())
188+
// std::weak_ptr<psi::Psi<T, Device>> psig = this->psi_init->share_psig();
189+
psi::Psi<T, Device>* psig_ = this->psi_init->share_psig();
190+
if (/*psig.expired()*/ psig_ == nullptr)
189191
{
190192
ModuleBase::WARNING_QUIT("PSIInit::initialize_psi", "psig lifetime is expired");
191193
}
192194

193195
//! to use psig, we need to lock it to get a shared pointer version,
194196
//! then switch kpoint of psig to the given one
195-
auto psig_ = psig.lock();
197+
// auto psig_ = psig.lock();
196198
// CHANGE LOG: if not lcaoinpw, the psig will only be used in psi-initialization
197199
// so we can only allocate memory for one kpoint with the maximal number of pw
198200
// over all kpoints, then the memory space will be always enough. Then for each
@@ -210,6 +212,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
210212
if (((this->ks_solver == "cg") || (this->ks_solver == "lapack")) && (this->basis_type == "pw"))
211213
{
212214
// the following function is only run serially, to be improved
215+
// For GPU: this psig_ should be on GPU before calling the following function
213216
hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init(p_hamilt,
214217
psig_->get_pointer(),
215218
psig_->get_nbands(),
@@ -218,6 +221,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
218221
etatom.data());
219222
continue;
220223
}
224+
// do nothing in LCAO_IN_PW case because psig is used to do transformation instead of initialization
221225
else if ((this->ks_solver == "lapack") && (this->basis_type == "lcao_in_pw"))
222226
{
223227
if (ik == 0)
@@ -239,6 +243,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
239243
}
240244

241245
// for the Davidson method, we just copy the wavefunction (partially)
246+
// For GPU: although this is simply a copy operation, if a GPU is present, this should be a data sending operation
242247
for (int iband = 0; iband < kspw_psi->get_nbands(); iband++)
243248
{
244249
for (int ibasis = 0; ibasis < kspw_psi->get_nbasis(); ibasis++)
@@ -248,7 +253,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
248253
}
249254
} // end k-point loop
250255

251-
if (this->basis_type != "lcao_in_pw")
256+
if (this->basis_type != "lcao_in_pw") // if not LCAO_IN_PW case, we can release the memory of psig after initialization is done.
252257
{
253258
this->psi_init->deallocate_psig();
254259
}

source/module_psi/psi_init.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class PSIInit
4141
// make interpolate table
4242
void make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell);
4343

44-
//------------------------ only for psi_initializer --------------------
44+
//------------------------ only for PsiInitializer --------------------
4545
/**
4646
* @brief initialize the wavefunction
4747
*
@@ -58,27 +58,27 @@ class PSIInit
5858
const bool is_already_initpsi);
5959

6060
/**
61-
* @brief get the psi_initializer
61+
* @brief get the psig held by the PsiInitializer
6262
*
63-
* @return psi_initializer<T, Device>*
63+
* @return psi::Psi<T, Device>* pointer to psig, as returned by share_psig()
6464
*/
65-
std::weak_ptr<psi::Psi<T, Device>> get_psig() const
65+
psi::Psi<T, Device>* get_psig() const
6666
{
6767
return this->psi_init->share_psig();
6868
}
6969
//----------------------------------------------------------------------
7070

7171
private:
72-
// psi_initializer<T, Device>* psi_init = nullptr;
72+
// PsiInitializer<T, Device>* psi_init = nullptr;
7373
// change to use smart pointer to manage the memory, and avoid memory leak
7474
// while the std::make_unique() is not supported till C++14,
7575
// so use the new and std::unique_ptr to manage the memory, but this makes new-delete not symmetric
76-
std::unique_ptr<psi_initializer<T, Device>> psi_init;
76+
std::unique_ptr<PsiInitializer<T, Device>> psi_init;
7777

7878
//! temporary: wave functions, this one may be deleted in future
7979
wavefunc wf_old;
8080

81-
// whether to use psi_initializer
81+
// whether to use PsiInitializer
8282
bool use_psiinitializer = false;
8383

8484
// wavefunction initialization type

source/module_psi/psi_initializer.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
#endif
1313

1414
template<typename T, typename Device>
15-
psi::Psi<std::complex<double>>* psi_initializer<T, Device>::allocate(const bool only_psig)
15+
psi::Psi<std::complex<double>>* PsiInitializer<T, Device>::allocate(const bool only_psig)
1616
{
17-
ModuleBase::timer::tick("psi_initializer", "allocate");
17+
ModuleBase::timer::tick("PsiInitializer", "allocate");
1818
/*
1919
WARNING: when basis_type = "pw", the variable PARAM.globalv.nlocal will also be set, in this case, it is set to
2020
9 = 1 + 3 + 5, which is the maximal number of orbitals spd, I don't think it is reasonable
@@ -82,13 +82,17 @@ psi::Psi<std::complex<double>>* psi_initializer<T, Device>::allocate(const bool
8282
// std::cout << " MEMORY FOR PSI PER PROCESSOR (MB) : " << double(memory_cost_psi)/1024.0/1024.0 << std::endl;
8383
ModuleBase::Memory::record("Psi_PW", memory_cost_psi);
8484
}
85-
// psi_initializer also works for basis transformation tasks. In this case, psig needs to allocate memory for
85+
// PsiInitializer also works for basis transformation tasks. In this case, psig needs to allocate memory for
8686
// each kpoint, otherwise, for initializing pw wavefunction, only one kpoint's space is enough.
8787
const int nks_psig = (PARAM.inp.basis_type == "pw")? 1 : nks_psi;
88-
this->psig_ = std::make_shared<psi::Psi<T, Device>>(nks_psig,
89-
nbands_actual,
90-
nbasis_actual,
91-
this->pw_wfc_->npwk);
88+
// this->psig_ = std::make_shared<psi::Psi<T, Device>>(nks_psig,
89+
// nbands_actual,
90+
// nbasis_actual,
91+
// this->pw_wfc_->npwk);
92+
this->d_psig_ = new psi::Psi<T, Device>(nks_psig,
93+
nbands_actual,
94+
nbasis_actual,
95+
this->pw_wfc_->npwk);
9296

9397
double memory_cost_psig =
9498
nks_psig * nbands_actual * this->pw_wfc_->npwk_max * PARAM.globalv.npol * sizeof(T);
@@ -111,14 +115,14 @@ psi::Psi<std::complex<double>>* psi_initializer<T, Device>::allocate(const bool
111115
<< "npwk_max = " << this->pw_wfc_->npwk_max << "\n"
112116
<< "npol = " << PARAM.globalv.npol << "\n";
113117
ModuleBase::Memory::record("psigPW", memory_cost_psig);
114-
ModuleBase::timer::tick("psi_initializer", "allocate");
118+
ModuleBase::timer::tick("PsiInitializer", "allocate");
115119
return psi_out;
116120
}
117121

118122
template<typename T, typename Device>
119-
void psi_initializer<T, Device>::random_t(T* psi, const int iw_start, const int iw_end, const int ik)
123+
void PsiInitializer<T, Device>::random_t(T* psi, const int iw_start, const int iw_end, const int ik)
120124
{
121-
ModuleBase::timer::tick("psi_initializer", "random_t");
125+
ModuleBase::timer::tick("PsiInitializer", "random_t");
122126
assert(iw_start >= 0);
123127
const int ng = this->pw_wfc_->npwk[ik];
124128

@@ -213,14 +217,14 @@ void psi_initializer<T, Device>::random_t(T* psi, const int iw_start, const int
213217
}
214218
}
215219
}
216-
ModuleBase::timer::tick("psi_initializer_random", "random_t");
220+
ModuleBase::timer::tick("PsiInitializer", "random_t");
217221
}
218222

219223
#ifdef __MPI
220224
template<typename T, typename Device>
221-
void psi_initializer<T, Device>::stick_to_pool(Real* stick, const int& ir, Real* out) const
225+
void PsiInitializer<T, Device>::stick_to_pool(Real* stick, const int& ir, Real* out) const
222226
{
223-
ModuleBase::timer::tick("psi_initializer", "stick_to_pool");
227+
ModuleBase::timer::tick("PsiInitializer", "stick_to_pool");
224228
MPI_Status ierror;
225229
const int is = this->ixy2is_[ir];
226230
const int ip = this->pw_wfc_->fftixy2ip[ir];
@@ -245,7 +249,7 @@ void psi_initializer<T, Device>::stick_to_pool(Real* stick, const int& ir, Real*
245249
}
246250
else
247251
{
248-
ModuleBase::WARNING_QUIT("psi_initializer", "stick_to_pool: Real type not supported");
252+
ModuleBase::WARNING_QUIT("PsiInitializer", "stick_to_pool: Real type not supported");
249253
}
250254
for(int iz=0; iz<nz; iz++)
251255
{
@@ -264,25 +268,25 @@ void psi_initializer<T, Device>::stick_to_pool(Real* stick, const int& ir, Real*
264268
}
265269
else
266270
{
267-
ModuleBase::WARNING_QUIT("psi_initializer", "stick_to_pool: Real type not supported");
271+
ModuleBase::WARNING_QUIT("PsiInitializer", "stick_to_pool: Real type not supported");
268272
}
269273
}
270274

271275
return;
272-
ModuleBase::timer::tick("psi_initializer", "stick_to_pool");
276+
ModuleBase::timer::tick("PsiInitializer", "stick_to_pool");
273277
}
274278
#endif
275279

276280
// explicit instantiation
277-
template class psi_initializer<std::complex<double>, base_device::DEVICE_CPU>;
278-
template class psi_initializer<std::complex<float>, base_device::DEVICE_CPU>;
281+
template class PsiInitializer<std::complex<double>, base_device::DEVICE_CPU>;
282+
template class PsiInitializer<std::complex<float>, base_device::DEVICE_CPU>;
279283
// gamma point calculation
280-
template class psi_initializer<double, base_device::DEVICE_CPU>;
281-
template class psi_initializer<float, base_device::DEVICE_CPU>;
284+
template class PsiInitializer<double, base_device::DEVICE_CPU>;
285+
template class PsiInitializer<float, base_device::DEVICE_CPU>;
282286
#if ((defined __CUDA) || (defined __ROCM))
283-
template class psi_initializer<std::complex<double>, base_device::DEVICE_GPU>;
284-
template class psi_initializer<std::complex<float>, base_device::DEVICE_GPU>;
287+
template class PsiInitializer<std::complex<double>, base_device::DEVICE_GPU>;
288+
template class PsiInitializer<std::complex<float>, base_device::DEVICE_GPU>;
285289
// gamma point calculation
286-
template class psi_initializer<double, base_device::DEVICE_GPU>;
287-
template class psi_initializer<float, base_device::DEVICE_GPU>;
290+
template class PsiInitializer<double, base_device::DEVICE_GPU>;
291+
template class PsiInitializer<float, base_device::DEVICE_GPU>;
288292
#endif

0 commit comments

Comments
 (0)