Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions source/module_esolver/esolver_ks_lcaopw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,15 @@ namespace ModuleESolver
// It is not a good choice to overload another solve function here, this will spoil the concept of
// multiple inheritance and polymorphism. But for now, we just do it in this way.
// In the future, there will be a series of class ESolver_KS_LCAO_PW, HSolver_LCAO_PW and so on.
std::weak_ptr<psi::Psi<T>> psig = this->p_wf_init->get_psig();
auto psig = this->p_wf_init->get_psig();

if (psig.expired())
if (/*psig.expired()*/ psig == nullptr)
{
ModuleBase::WARNING_QUIT("ESolver_KS_PW::hamilt2density_single", "psig lifetime is expired");
}

hsolver::HSolverLIP<T> hsolver_lip_obj(this->pw_wfc);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig.lock().get()[0], skip_charge,ucell.tpiba,ucell.nat);
hsolver_lip_obj.solve(this->p_hamilt, this->kspw_psi[0], this->pelec, psig[0], skip_charge, ucell.tpiba, ucell.nat);

// add exx
#ifdef __EXX
Expand Down
2 changes: 1 addition & 1 deletion source/module_esolver/esolver_ks_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class ESolver_KS_PW : public ESolver_KS<T, Device>
//! hide the psi in ESolver_KS for tmp use
psi::Psi<std::complex<double>, base_device::DEVICE_CPU>* psi = nullptr;

// psi_initializer controller
// PsiInitializer controller
psi::PSIInit<T, Device>* p_wf_init = nullptr;

Device* ctx = {};
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/read_input_item_postprocess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ void ReadInput::item_postprocess()

In the future lcao_in_pw will have its own ESolver.

2023/12/22 use new psi_initializer to expand numerical
2023/12/22 use new PsiInitializer to expand numerical
atomic orbitals, ykhuang
*/
if (para.input.towannier90 && para.input.basis_type == "lcao_in_pw")
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/read_input_item_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ void ReadInput::item_system()
}
{
Input_Item item("psi_initializer");
item.annotation = "whether to use psi_initializer";
item.annotation = "whether to use PsiInitializer";
item.reset_value = [](const Input_Item& item, Parameter& para) {
if (para.input.basis_type == "lcao_in_pw")
{
Expand Down
8 changes: 4 additions & 4 deletions source/module_io/to_wannier90_lcao_in_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void toWannier90_LCAO_IN_PW::calculate(

Structure_Factor* sf_ptr = const_cast<Structure_Factor*>(&sf);
ModulePW::PW_Basis_K* wfcpw_ptr = const_cast<ModulePW::PW_Basis_K*>(wfcpw);
this->psi_init_ = new psi_initializer_nao<std::complex<double>, base_device::DEVICE_CPU>();
this->psi_init_ = new PsiInitializerNAO<std::complex<double>, base_device::DEVICE_CPU>();
#ifdef __MPI
this->psi_init_->initialize(sf_ptr, wfcpw_ptr, &ucell, &(GlobalC::Pkpoints), 1, nullptr, GlobalV::MY_RANK);
#else
Expand Down Expand Up @@ -218,16 +218,16 @@ void toWannier90_LCAO_IN_PW::nao_G_expansion(
{
int npwx = wfcpw->npwk_max;
this->psi_init_->proj_ao_onkG(ik);
std::weak_ptr<psi::Psi<std::complex<double>>> psig = this->psi_init_->share_psig();
if(psig.expired()) { ModuleBase::WARNING_QUIT("toWannier90_LCAO_IN_PW::nao_G_expansion", "psig is expired");
auto psig = this->psi_init_->share_psig();
if(/*psig.expired()*/ psig == nullptr) { ModuleBase::WARNING_QUIT("toWannier90_LCAO_IN_PW::nao_G_expansion", "psig is expired");
}
int nbands = PARAM.globalv.nlocal;
int nbasis = npwx*PARAM.globalv.npol;
for (int ib = 0; ib < nbands; ib++)
{
for (int ig = 0; ig < nbasis; ig++)
{
psi(ib, ig) = psig.lock().get()[0](ik, ib, ig);
psi(ib, ig) = psig[0](ik, ib, ig);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion source/module_io/to_wannier90_lcao_in_pw.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class toWannier90_LCAO_IN_PW : public toWannier90_PW
protected:
const Parallel_Orbitals* ParaV;
/// @brief psi initializer for expanding nao in planewave basis
psi_initializer<std::complex<double>, base_device::DEVICE_CPU>* psi_init_;
PsiInitializer<std::complex<double>, base_device::DEVICE_CPU>* psi_init_;

/// @brief get Bloch function from LCAO wavefunction
/// @param psi_in
Expand Down
154 changes: 98 additions & 56 deletions source/module_psi/psi_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,35 @@ PSIInit<T, Device>::PSIInit(const std::string& init_wfc_in,
this->basis_type = basis_type_in;
this->use_psiinitializer = use_psiinitializer_in;
this->pw_wfc = pw_wfc_in;

if (PARAM.inp.psi_initializer == true)
{
this->init_psi_method = "new";
}
else
{
if (PARAM.inp.init_wfc == "file" || PARAM.inp.device == "gpu" || PARAM.inp.esolver_type == "sdft")
{
this->init_psi_method = "old"; // old method;
}
else
{
this->init_psi_method = "new"; // new method;
}
}
}

/// @brief Release the psig storage held by the PsiInitializer, if one exists.
///
/// Only the "new" initialization method uses `psi_init`. The pointer is checked
/// before use because, as far as this file shows, it is assigned only inside
/// prepare_init(); an object destroyed before that call (or after prepare_init()
/// quit early) would otherwise dereference a null std::unique_ptr.
/// The PsiInitializer object itself is destroyed by the std::unique_ptr member,
/// so no explicit delete is needed here.
template <typename T, typename Device>
PSIInit<T, Device>::~PSIInit()
{
    if (this->init_psi_method == "new" && this->psi_init != nullptr)
    {
        this->psi_init->deallocate_psig();
    }
}

template <typename T, typename Device>
Expand All @@ -37,53 +66,56 @@ void PSIInit<T, Device>::prepare_init(Structure_Factor* p_sf,
#endif
pseudopot_cell_vnl* p_ppcell)
{
if (!this->use_psiinitializer)
if (this->init_psi_method == "old")
{
return;
}
// under restriction of C++11, std::unique_ptr can not be allocate via std::make_unique
// use new instead, but will cause asymmetric allocation and deallocation, in literal aspect
ModuleBase::timer::tick("PSIInit", "prepare_init");
if ((this->init_wfc.substr(0, 6) == "atomic") && (p_ucell->natomwfc == 0))
{
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_random<T, Device>());
}
else if (this->init_wfc == "atomic")
{
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_atomic<T, Device>());
}
else if (this->init_wfc == "random")
{
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_random<T, Device>());
}
else if (this->init_wfc == "nao")
{
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_nao<T, Device>());
}
else if (this->init_wfc == "atomic+random")
{
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_atomic_random<T, Device>());
}
else if (this->init_wfc == "nao+random")
{
this->psi_init = std::unique_ptr<psi_initializer<T, Device>>(new psi_initializer_nao_random<T, Device>());
}
else
{
ModuleBase::WARNING_QUIT("PSIInit::prepare_init", "for new psi initializer, init_wfc type not supported");
}
// under the restriction of C++11, std::unique_ptr cannot be allocated via std::make_unique;
// use new instead, although this makes allocation and deallocation look asymmetric
ModuleBase::timer::tick("PSIInit", "prepare_init");
if ((this->init_wfc.substr(0, 6) == "atomic") && (p_ucell->natomwfc == 0))
{
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerRandom<T, Device>());
}
else if (this->init_wfc == "atomic")
{
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerAtomic<T, Device>());
}
else if (this->init_wfc == "random")
{
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerRandom<T, Device>());
}
else if (this->init_wfc == "nao")
{
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerNAO<T, Device>());
}
else if (this->init_wfc == "atomic+random")
{
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerAtomicRandom<T, Device>());
}
else if (this->init_wfc == "nao+random")
{
this->psi_init = std::unique_ptr<PsiInitializer<T, Device>>(new PsiInitializerNAORandom<T, Device>());
}
else
{
ModuleBase::WARNING_QUIT("PSIInit::prepare_init", "for new psi initializer, init_wfc type not supported");
}

//! function polymorphism is moved from constructor to function initialize.
//! Two slightly different implementation are for MPI and serial case, respectively.
//! function polymorphism is moved from constructor to function initialize.
//! Two slightly different implementations are provided for the MPI and serial cases, respectively.
#ifdef __MPI
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, p_parak, random_seed, p_ppcell, rank);
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, p_parak, random_seed, p_ppcell, rank);
#else
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, random_seed, p_ppcell);
this->psi_init->initialize(p_sf, pw_wfc, p_ucell, random_seed, p_ppcell);
#endif

// always new->initialize->tabulate->allocate->proj_ao_onkG
this->psi_init->tabulate();
ModuleBase::timer::tick("PSIInit", "prepare_init");
// always new->initialize->tabulate->allocate->proj_ao_onkG
this->psi_init->tabulate();
ModuleBase::timer::tick("PSIInit", "prepare_init");
}
}

template <typename T, typename Device>
Expand All @@ -108,9 +140,9 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
// the basis (representation) with operator (hamiltonian) and solver (diagonalization).
// This feature requires feasible Linear Algebra library in-built in ABACUS, which
// is not ready yet.
if (this->use_psiinitializer) // new method
if (this->init_psi_method == "new") // new method
{
// psi_initializer drag initialization of pw wavefunction out of HSolver, make psi
// PsiInitializer drags the initialization of the pw wavefunction out of HSolver, making psi
// initialization decoupled from the HSolver (diagonalization) procedure.
// However, because EXX is hard to maintain, we still use the old method for EXX.
// LCAOINPW in version >= 3.5.0 uses this new method.
Expand All @@ -137,27 +169,31 @@ void PSIInit<T, Device>::allocate_psi(Psi<std::complex<double>>*& psi,
template <typename T, typename Device>
void PSIInit<T, Device>::make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell)
{
if (this->use_psiinitializer)
if (this->init_psi_method == "new")
{
} // do not need to do anything because the interpolate table is unchanged
else // old initialization method, used in EXX calculation
{
this->wf_old.init_after_vc(nks); // reallocate wanf2, the planewave expansion of lcao
this->wf_old.init_at_1(p_sf, &p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
this->wf_old.init_at_1(
p_sf,
&p_ppcell->tab_at); // re-calculate tab_at, the overlap matrix between atomic pswfc and jlq
}
}

// in the following function, the psi on Device will be initialized with the CPU psi
template <typename T, typename Device>
void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
psi::Psi<T, Device>* kspw_psi,
hamilt::Hamilt<T, Device>* p_hamilt,
const pseudopot_cell_vnl& nlpp,
std::ofstream& ofs_running,
const bool is_already_initpsi)
void PSIInit<T, Device>::initialize_psi(
Psi<std::complex<double>>* psi, // the one always on CPU
psi::Psi<T, Device>* kspw_psi, // the one may be on GPU. In CPU case, it is the same as psi
hamilt::Hamilt<T, Device>* p_hamilt,
const pseudopot_cell_vnl& nlpp,
std::ofstream& ofs_running,
const bool is_already_initpsi)
{
ModuleBase::timer::tick("PSIInit", "initialize_psi");

if (PARAM.inp.psi_initializer)
if (this->init_psi_method == "new")
{
// if psig is not allocated before, allocate it
if (!this->psi_init->psig_use_count())
Expand All @@ -169,7 +205,8 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
// like (1, nbands, npwx), in which npwx is the maximal npw of all kpoints
for (int ik = 0; ik < this->pw_wfc->nks; ik++)
{
//! Fix the wavefunction to initialize at given kpoint
//! Fix the wavefunction to initialize at given kpoint.
// This will fix the kpoint for CPU case. For GPU, we should additionally call fix_k for kspw_psi
psi->fix_k(ik);

//! Update Hamiltonian from other kpoint to the given one
Expand All @@ -179,20 +216,20 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
//! and G is the wavevector of the periodic part of the Bloch function
this->psi_init->proj_ao_onkG(ik);

//! psi_initializer manages memory of psig with shared pointer,
//! PsiInitializer manages memory of psig with shared pointer,
//! access to it is shared here via a non-owning raw pointer
//! therefore once the psi_initializer is destructed, psig will be destructed, too
//! therefore once the PsiInitializer is destructed, psig will be destructed, too
//! this way, we can avoid memory leak and undefined behavior
std::weak_ptr<psi::Psi<T, Device>> psig = this->psi_init->share_psig();

if (psig.expired())
// std::weak_ptr<psi::Psi<T, Device>> psig = this->psi_init->share_psig();
psi::Psi<T, Device>* psig_ = this->psi_init->share_psig();
if (/*psig.expired()*/ psig_ == nullptr)
{
ModuleBase::WARNING_QUIT("PSIInit::initialize_psi", "psig lifetime is expired");
}

//! psig_ is already a plain pointer, so no lock() is needed before using it;
//! next, switch kpoint of psig to the given one
auto psig_ = psig.lock();
// auto psig_ = psig.lock();
// CHANGE LOG: if not lcaoinpw, the psig will only be used in psi-initialization
// so we can only allocate memory for one kpoint with the maximal number of pw
// over all kpoints, then the memory space will be always enough. Then for each
Expand All @@ -210,6 +247,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
if (((this->ks_solver == "cg") || (this->ks_solver == "lapack")) && (this->basis_type == "pw"))
{
// the following function is only run serially, to be improved
// For GPU: this psig_ should be on GPU before calling the following function
hsolver::DiagoIterAssist<T, Device>::diagH_subspace_init(p_hamilt,
psig_->get_pointer(),
psig_->get_nbands(),
Expand All @@ -218,6 +256,7 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
etatom.data());
continue;
}
// do nothing in LCAO_IN_PW case because psig is used to do transformation instead of initialization
else if ((this->ks_solver == "lapack") && (this->basis_type == "lcao_in_pw"))
{
if (ik == 0)
Expand All @@ -239,6 +278,8 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
}

// for the Davidson method, we just copy the wavefunction (partially)
// For GPU: although this is simply the copy operation, if GPU present, this should be a data sending
// operation
for (int iband = 0; iband < kspw_psi->get_nbands(); iband++)
{
for (int ibasis = 0; ibasis < kspw_psi->get_nbasis(); ibasis++)
Expand All @@ -248,7 +289,8 @@ void PSIInit<T, Device>::initialize_psi(Psi<std::complex<double>>* psi,
}
} // end k-point loop

if (this->basis_type != "lcao_in_pw")
if (this->basis_type
!= "lcao_in_pw") // if not LCAO_IN_PW case, we can release the memory of psig after initialization is done.
{
this->psi_init->deallocate_psig();
}
Expand Down
18 changes: 10 additions & 8 deletions source/module_psi/psi_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class PSIInit
const std::string& basis_type_in,
const bool& use_psiinitializer_in,
ModulePW::PW_Basis_K* pw_wfc_in);
~PSIInit(){};
~PSIInit();

// prepare the wavefunction initialization
void prepare_init(Structure_Factor* p_sf, //< structure factor
Expand All @@ -41,7 +41,7 @@ class PSIInit
// make interpolate table
void make_table(const int nks, Structure_Factor* p_sf, pseudopot_cell_vnl* p_ppcell);

//------------------------ only for psi_initializer --------------------
//------------------------ only for PsiInitializer --------------------
/**
* @brief initialize the wavefunction
*
Expand All @@ -58,27 +58,27 @@ class PSIInit
const bool is_already_initpsi);

/**
* @brief get the psi_initializer
* @brief get the psig held by the PsiInitializer
*
* @return psi_initializer<T, Device>*
* @return psi::Psi<T, Device>* the pointer returned by PsiInitializer::share_psig()
*/
std::weak_ptr<psi::Psi<T, Device>> get_psig() const
psi::Psi<T, Device>* get_psig() const
{
return this->psi_init->share_psig();
}
//----------------------------------------------------------------------

private:
// psi_initializer<T, Device>* psi_init = nullptr;
// PsiInitializer<T, Device>* psi_init = nullptr;
// change to use smart pointer to manage the memory, and avoid memory leak
// while the std::make_unique() is not supported till C++14,
// so use the new and std::unique_ptr to manage the memory, but this makes new-delete not symmetric
std::unique_ptr<psi_initializer<T, Device>> psi_init;
std::unique_ptr<PsiInitializer<T, Device>> psi_init;

//! temporary: wave functions, this one may be deleted in future
wavefunc wf_old;

// whether to use psi_initializer
// whether to use PsiInitializer
bool use_psiinitializer = false;

// wavefunction initialization type
Expand All @@ -94,6 +94,8 @@ class PSIInit
ModulePW::PW_Basis_K* pw_wfc = nullptr;

Device* ctx = {};

std::string init_psi_method = "old";
};

} // namespace psi
Expand Down
Loading
Loading