Skip to content

Commit a5e3109

Browse files
committed
Remove global dependence of pdm related functions in DeePKS.
1 parent d6ddec8 commit a5e3109

File tree

13 files changed

+201
-118
lines changed

13 files changed

+201
-118
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,8 @@ OBJS_DEEPKS=LCAO_deepks.o\
201201
deepks_vdelta.o\
202202
deepks_vdpre.o\
203203
deepks_hmat.o\
204+
deepks_pdm.o\
204205
LCAO_deepks_io.o\
205-
LCAO_deepks_pdm.o\
206206
LCAO_deepks_phialpha.o\
207207
LCAO_deepks_interface.o\
208208

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,14 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
227227
// load the DeePKS model from deep neural network
228228
DeePKS_domain::load_model(PARAM.inp.deepks_model, GlobalC::ld.model_deepks);
229229
// read pdm from file for NSCF or SCF-restart, do it only once in whole calculation
230-
GlobalC::ld.read_projected_DM((PARAM.inp.init_chg == "file"), PARAM.inp.deepks_equiv, *orb_.Alpha);
230+
DeePKS_domain::read_pdm((PARAM.inp.init_chg == "file"),
231+
PARAM.inp.deepks_equiv,
232+
GlobalC::ld.init_pdm,
233+
GlobalC::ld.inlmax,
234+
GlobalC::ld.lmaxd,
235+
GlobalC::ld.inl_l,
236+
*orb_.Alpha,
237+
GlobalC::ld.pdm);
231238
}
232239
#endif
233240

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_gamma.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,6 @@ void Force_LCAO<double>::ftable(const bool isforce,
252252
if (PARAM.inp.deepks_scf)
253253
{
254254
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
255-
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);
256-
257255
DeePKS_domain::cal_descriptor(ucell.nat,
258256
GlobalC::ld.inlmax,
259257
GlobalC::ld.inl_l,

source/module_hamilt_lcao/hamilt_lcaodft/FORCE_k.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,6 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
347347
const std::vector<std::vector<std::complex<double>>>& dm_k = dm->get_DMK_vector();
348348

349349
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
350-
// GlobalC::ld.cal_projected_DM(dm, ucell, orb, gd);
351-
352350
std::vector<torch::Tensor> descriptor;
353351
DeePKS_domain::cal_descriptor(ucell.nat,
354352
GlobalC::ld.inlmax,

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/deepks_lcao.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,18 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
160160
{
161161
ModuleBase::timer::tick("DeePKS", "contributeHR");
162162

163-
GlobalC::ld.cal_projected_DM<TK>(this->DM, *this->ucell, *ptr_orb_, *(this->gd));
163+
DeePKS_domain::cal_pdm<TK>(GlobalC::ld.init_pdm,
164+
GlobalC::ld.inlmax,
165+
GlobalC::ld.lmaxd,
166+
GlobalC::ld.inl_l,
167+
GlobalC::ld.inl_index,
168+
this->DM,
169+
GlobalC::ld.phialpha,
170+
*this->ucell,
171+
*ptr_orb_,
172+
*(this->gd),
173+
*(this->hR->get_paraV()),
174+
GlobalC::ld.pdm);
164175

165176
std::vector<torch::Tensor> descriptor;
166177
DeePKS_domain::cal_descriptor(this->ucell->nat,

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ if(ENABLE_DEEPKS)
1111
deepks_vdelta.cpp
1212
deepks_vdpre.cpp
1313
deepks_hmat.cpp
14+
deepks_pdm.cpp
1415
LCAO_deepks_io.cpp
15-
LCAO_deepks_pdm.cpp
1616
LCAO_deepks_phialpha.cpp
1717
LCAO_deepks_interface.cpp
1818
)

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 2 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "deepks_hmat.h"
1111
#include "deepks_orbital.h"
1212
#include "deepks_orbpre.h"
13+
#include "deepks_pdm.h"
1314
#include "deepks_spre.h"
1415
#include "deepks_vdelta.h"
1516
#include "deepks_vdpre.h"
@@ -20,13 +21,11 @@
2021
#include "module_basis/module_ao/parallel_orbitals.h"
2122
#include "module_basis/module_nao/two_center_integrator.h"
2223
#include "module_cell/module_neighbor/sltk_grid_driver.h"
23-
#include "module_elecstate/module_dm/density_matrix.h"
2424
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"
2525
#include "module_io/winput.h"
2626

2727
#include <torch/script.h>
2828
#include <torch/torch.h>
29-
#include <unordered_map>
3029

3130
///
3231
/// The LCAO_Deepks contains subroutines for implementation of the DeePKS method in atomic basis.
@@ -102,7 +101,7 @@ class LCAO_Deepks
102101
ModuleBase::IntArray* alpha_index; // seems not used in the code
103102
ModuleBase::IntArray* inl_index; // caoyu add 2021-05-07
104103

105-
bool init_pdm = false; // for DeePKS NSCF calculation
104+
bool init_pdm = false; // for DeePKS NSCF calculation, set init_pdm to skip the calculation of pdm in SCF iteration
106105

107106
// deep neural network module that provides corrected Hamiltonian term and
108107
// related derivatives. Used in cal_gedm.
@@ -194,52 +193,6 @@ class LCAO_Deepks
194193
const LCAO_Orbitals& orb,
195194
const Grid_Driver& GridD);
196195

197-
//-------------------
198-
// LCAO_deepks_pdm.cpp
199-
//-------------------
200-
201-
// This file contains subroutines for calculating pdm,
202-
// which is defind as sum_mu,nu rho_mu,nu (<chi_mu|alpha><alpha|chi_nu>);
203-
// as well as gdmx, which is the gradient of pdm, defined as
204-
// sum_mu,nu rho_mu,nu d/dX(<chi_mu|alpha><alpha|chi_nu>)
205-
206-
// It also contains subroutines for printing pdm and gdmx
207-
// for checking purpose
208-
209-
// There are 2 subroutines in this file:
210-
// 1. cal_projected_DM, which is used for calculating pdm
211-
// 2. check_projected_dm, which prints pdm to descriptor.dat
212-
213-
public:
214-
/**
215-
* @brief calculate projected density matrix:
216-
* pdm = sum_i,occ <phi_i|alpha1><alpha2|phi_k>
217-
* 3 cases to skip calculation of pdm:
218-
* 1. NSCF calculation of DeePKS, init_chg = file and pdm has been read
219-
* 2. SCF calculation of DeePKS with init_chg = file and pdm has been read for restarting SCF
220-
* 3. Relax/Cell-Relax/MD calculation, non-first step will use the convergence pdm from the last step as initial
221-
* pdm
222-
*/
223-
template <typename TK>
224-
void cal_projected_DM(const elecstate::DensityMatrix<TK, double>* dm,
225-
const UnitCell& ucell,
226-
const LCAO_Orbitals& orb,
227-
const Grid_Driver& GridD);
228-
229-
void check_projected_dm();
230-
231-
/**
232-
* @brief set init_pdm to skip the calculation of pdm in SCF iteration
233-
*/
234-
void set_init_pdm(bool ipdm)
235-
{
236-
this->init_pdm = ipdm;
237-
}
238-
/**
239-
* @brief read pdm from file, do it only once in whole calculation
240-
*/
241-
void read_projected_DM(bool read_pdm_file, bool is_equiv, const Numerical_Orbital& alpha);
242-
243196
public:
244197
//! a temporary interface for cal_e_delta_band
245198
template <typename TK>

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
4040
const int des_per_atom = ld->des_per_atom;
4141
const int* inl_l = ld->inl_l;
4242
const ModuleBase::IntArray* inl_index = ld->inl_index;
43-
const std::vector<torch::Tensor> pdm = ld->pdm;
4443
const std::vector<hamilt::HContainer<double>*> phialpha = ld->phialpha;
44+
std::vector<torch::Tensor> pdm = ld->pdm;
4545

4646
const int my_rank = GlobalV::MY_RANK;
4747
const int nspin = PARAM.inp.nspin;
@@ -348,10 +348,11 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
348348
// when deepks_scf is on, the init pdm should be same as the out pdm, so we should not recalculate the pdm
349349
if (!PARAM.inp.deepks_scf)
350350
{
351-
ld->cal_projected_DM<TK>(dm, ucell, orb, GridD);
351+
DeePKS_domain::cal_pdm<
352+
TK>(ld->init_pdm, inlmax, lmaxd, inl_l, inl_index, dm, phialpha, ucell, orb, GridD, *ParaV, pdm);
352353
}
353354

354-
ld->check_projected_dm(); // print out the projected dm for NSCF calculaiton
355+
DeePKS_domain::check_pdm(inlmax, inl_l, pdm); // print out the projected dm for NSCF calculaiton
355356

356357
std::vector<torch::Tensor> descriptor;
357358
DeePKS_domain::cal_descriptor(nat, inlmax, inl_l, pdm, descriptor,

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "module_base/complexmatrix.h"
77
#include "module_base/matrix.h"
88
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
9+
910
#include <memory>
1011

1112
template <typename TK, typename TR>
@@ -26,11 +27,9 @@ class LCAO_Deepks_Interface
2627
/// @param[in] orb
2728
/// @param[in] GridD
2829
/// @param[in] ParaV
29-
/// @param[in] psi
3030
/// @param[in] psid
31-
/// @param[in] dm_gamma
32-
/// @param[in] dm_k
33-
// for Gamma-only
31+
/// @param[in] dm
32+
/// @param[in] p_ham
3433
void out_deepks_labels(const double& etot,
3534
const int& nks,
3635
const int& nat,

0 commit comments

Comments
 (0)