Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ OBJS_DEEPKS=LCAO_deepks.o\
deepks_descriptor.o\
deepks_force.o\
deepks_fpre.o\
deepks_iterate.o\
deepks_spre.o\
deepks_orbital.o\
deepks_orbpre.o\
Expand Down
7 changes: 4 additions & 3 deletions source/module_esolver/esolver_ks_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa

#ifdef __DEEPKS
// 10) initialize deepks
LCAO_domain::DeePKS_init(ucell, pv, this->kv.get_nks(), orb_, this->ld);
LCAO_domain::DeePKS_init(ucell, pv, this->kv.get_nks(), orb_, this->ld, GlobalV::ofs_running);
if (PARAM.inp.deepks_scf)
{
// load the DeePKS model from deep neural network
Expand All @@ -220,6 +220,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
DeePKS_domain::read_pdm((PARAM.inp.init_chg == "file"),
PARAM.inp.deepks_equiv,
ld.init_pdm,
ucell.nat,
orb_.Alpha[0].getTotal_nchi() * ucell.nat,
ld.lmaxd,
ld.inl_l,
Expand All @@ -245,8 +246,8 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
"%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%"
"%%%%%%%%%%%%%%%%%%%%%%%%%%"
<< std::endl;
std::cout << " Warning: nks (" << this->kv.get_nks() << ") is not divisible by kpar (" << PARAM.globalv.kpar_lcao
<< ")." << std::endl;
std::cout << " Warning: nks (" << this->kv.get_nks() << ") is not divisible by kpar ("
<< PARAM.globalv.kpar_lcao << ")." << std::endl;
std::cout << " This may lead to poor load balance. It is strongly suggested to" << std::endl;
std::cout << " set nks to be divisible by kpar, but if this is really what" << std::endl;
std::cout << " you want, please ignore this warning." << std::endl;
Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/esolver_ks_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
// for grid integration
#include "module_hamilt_lcao/module_gint/gint_gamma.h"
#include "module_hamilt_lcao/module_gint/gint_k.h"
#include "module_hamilt_lcao/module_gint/temp_gint/gint_info.h"
#include "module_hamilt_lcao/module_gint/temp_gint/gint.h"
#include "module_hamilt_lcao/module_gint/temp_gint/gint_info.h"
#ifdef __DEEPKS
#include "module_hamilt_lcao/module_deepks/LCAO_deepks.h"
#endif
Expand Down Expand Up @@ -100,7 +100,7 @@ class ESolver_KS_LCAO : public ESolver_KS<TK>
//---------------------------------------------------------------------

#ifdef __DEEPKS
LCAO_Deepks ld;
LCAO_Deepks<TK> ld;
#endif

#ifdef __EXX
Expand Down
5 changes: 3 additions & 2 deletions source/module_esolver/lcao_after_scf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
if (this->psi != nullptr && (istep % PARAM.inp.out_interval == 0))
{
hamilt::HamiltLCAO<TK, TR>* p_ham_deepks = dynamic_cast<hamilt::HamiltLCAO<TK, TR>*>(this->p_hamilt);
std::shared_ptr<LCAO_Deepks> ld_shared_ptr(&ld, [](LCAO_Deepks*) {});
std::shared_ptr<LCAO_Deepks<TK>> ld_shared_ptr(&ld, [](LCAO_Deepks<TK>*) {});
LCAO_Deepks_Interface<TK, TR> deepks_interface(ld_shared_ptr);

deepks_interface.out_deepks_labels(this->pelec->f_en.etot,
Expand All @@ -229,7 +229,8 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(UnitCell& ucell, const int istep, const
&(this->pv),
*(this->psi),
dynamic_cast<const elecstate::ElecStateLCAO<TK>*>(this->pelec)->get_DM(),
p_ham_deepks);
p_ham_deepks,
GlobalV::MY_RANK);
}
#endif

Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_lcao/hamilt_lcaodft/FORCE.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class Force_LCAO
#ifdef __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
LCAO_Deepks& ld,
LCAO_Deepks<T>& ld,
#endif
typename TGint<T>::type& gint,
const TwoCenterBundle& two_center_bundle,
Expand Down
6 changes: 3 additions & 3 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ void Force_Stress_LCAO<T>::getForceStress(UnitCell& ucell,
ModulePW::PW_Basis* rhopw,
surchem& solvent,
#ifdef __DEEPKS
LCAO_Deepks& ld,
LCAO_Deepks<T>& ld,
#endif
#ifdef __EXX
Exx_LRI<double>& exx_lri_double,
Expand Down Expand Up @@ -838,7 +838,7 @@ void Force_Stress_LCAO<double>::integral_part(const bool isGammaOnly,
#if __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
LCAO_Deepks& ld,
LCAO_Deepks<double>& ld,
#endif
Gint_Gamma& gint_gamma, // mohan add 2024-04-01
Gint_k& gint_k, // mohan add 2024-04-01
Expand Down Expand Up @@ -895,7 +895,7 @@ void Force_Stress_LCAO<std::complex<double>>::integral_part(const bool isGammaOn
#if __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
LCAO_Deepks& ld,
LCAO_Deepks<std::complex<double>>& ld,
#endif
Gint_Gamma& gint_gamma,
Gint_k& gint_k,
Expand Down
4 changes: 2 additions & 2 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_STRESS.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Force_Stress_LCAO
ModulePW::PW_Basis* rhopw,
surchem& solvent,
#ifdef __DEEPKS
LCAO_Deepks& ld,
LCAO_Deepks<T>& ld,
#endif
#ifdef __EXX
Exx_LRI<double>& exx_lri_double,
Expand Down Expand Up @@ -99,7 +99,7 @@ class Force_Stress_LCAO
#if __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
LCAO_Deepks& ld,
LCAO_Deepks<T>& ld,
#endif
Gint_Gamma& gint_gamma,
Gint_k& gint_k,
Expand Down
18 changes: 2 additions & 16 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ void Force_LCAO<double>::ftable(const bool isforce,
#ifdef __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
LCAO_Deepks& ld,
LCAO_Deepks<double>& ld,
#endif
TGint<double>::type& gint,
const TwoCenterBundle& two_center_bundle,
Expand Down Expand Up @@ -252,21 +252,7 @@ void Force_LCAO<double>::ftable(const bool isforce,
{
const std::vector<std::vector<double>>& dm_gamma = dm->get_DMK_vector();

// These calculations have been done in LCAO_Deepks_Interface in after_scf
// std::vector<torch::Tensor> descriptor;
// DeePKS_domain::cal_descriptor(ucell.nat, ld.inlmax, ld.inl_l, ld.pdm, descriptor, ld.des_per_atom);
// DeePKS_domain::cal_edelta_gedm(ucell.nat,
// ld.lmaxd,
// ld.nmaxd,
// ld.inlmax,
// ld.des_per_atom,
// ld.inl_l,
// descriptor,
// ld.pdm,
// ld.model_deepks,
// ld.gedm,
// ld.E_delta);

// No need to update E_delta here since it have been done in LCAO_Deepks_Interface in after_scf
const int nks = 1;
DeePKS_domain::cal_f_delta<double>(dm_gamma,
ucell,
Expand Down
18 changes: 2 additions & 16 deletions source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
#ifdef __DEEPKS
ModuleBase::matrix& fvnl_dalpha,
ModuleBase::matrix& svnl_dalpha,
LCAO_Deepks& ld,
LCAO_Deepks<std::complex<double>>& ld,
#endif
TGint<std::complex<double>>::type& gint,
const TwoCenterBundle& two_center_bundle,
Expand Down Expand Up @@ -347,21 +347,7 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
{
const std::vector<std::vector<std::complex<double>>>& dm_k = dm->get_DMK_vector();

// These calculations have been done in LCAO_Deepks_Interface in after_scf
// std::vector<torch::Tensor> descriptor;
// DeePKS_domain::cal_descriptor(ucell.nat, ld.inlmax, ld.inl_l, ld.pdm, descriptor, ld.des_per_atom);
// DeePKS_domain::cal_edelta_gedm(ucell.nat,
// ld.lmaxd,
// ld.nmaxd,
// ld.inlmax,
// ld.des_per_atom,
// ld.inl_l,
// descriptor,
// ld.pdm,
// ld.model_deepks,
// ld.gedm,
// ld.E_delta);

// No need to update E_delta since it have been done in LCAO_Deepks_Interface in after_scf
DeePKS_domain::cal_f_delta<std::complex<double>>(dm_k,
ucell,
orb,
Expand Down
20 changes: 18 additions & 2 deletions source/module_hamilt_lcao/hamilt_lcaodft/LCAO_allocate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ namespace LCAO_domain
{
#ifdef __DEEPKS
// It seems it is only related to DeePKS, so maybe we should move it to DeeKS_domain
template <typename T>
void DeePKS_init(const UnitCell& ucell,
Parallel_Orbitals& pv,
const int& nks,
const LCAO_Orbitals& orb,
LCAO_Deepks& ld)
LCAO_Deepks<T>& ld,
std::ofstream& ofs)
{
ModuleBase::TITLE("LCAO_domain", "DeePKS_init");
// preparation for DeePKS
Expand All @@ -26,7 +28,7 @@ void DeePKS_init(const UnitCell& ucell,
na[it] = ucell.atoms[it].na;
}

ld.init(orb, ucell.nat, ucell.ntype, nks, pv, na);
ld.init(orb, ucell.nat, ucell.ntype, nks, pv, na, ofs);

if (PARAM.inp.deepks_scf)
{
Expand All @@ -35,5 +37,19 @@ void DeePKS_init(const UnitCell& ucell,
}
return;
}

template void DeePKS_init<double>(const UnitCell& ucell,
Parallel_Orbitals& pv,
const int& nks,
const LCAO_Orbitals& orb,
LCAO_Deepks<double>& ld,
std::ofstream& ofs);

template void DeePKS_init<std::complex<double>>(const UnitCell& ucell,
Parallel_Orbitals& pv,
const int& nks,
const LCAO_Orbitals& orb,
LCAO_Deepks<std::complex<double>>& ld,
std::ofstream& ofs);
#endif
} // namespace LCAO_domain
4 changes: 3 additions & 1 deletion source/module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,13 @@ void build_ST_new(ForceStressArrays& fsr,
void zeros_HSR(const char& mtype, LCAO_HS_Arrays& HS_arrays);

#ifdef __DEEPKS
template <typename T>
void DeePKS_init(const UnitCell& ucell,
Parallel_Orbitals& pv,
const int& nks,
const LCAO_Orbitals& orb,
LCAO_Deepks& ld);
LCAO_Deepks<T>& ld,
std::ofstream& ofs);
#endif

template <typename T>
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ HamiltLCAO<TK, TR>::HamiltLCAO(Gint_Gamma* GG_in,
elecstate::DensityMatrix<TK, double>* DM_in
#ifdef __DEEPKS
,
LCAO_Deepks* ld_in
LCAO_Deepks<TK>* ld_in
#endif
#ifdef __EXX
,
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class HamiltLCAO : public Hamilt<TK>
elecstate::DensityMatrix<TK, double>* DM_in
#ifdef __DEEPKS
,
LCAO_Deepks* ld_in
LCAO_Deepks<TK>* ld_in
#endif
#ifdef __EXX
,
Expand Down
Loading
Loading