Skip to content

Commit 56194a9

Browse files
committed
Remove the global dependence of functions related to phialpha in DeePKS.
1 parent 2576f32 commit 56194a9

File tree

9 files changed

+132
-82
lines changed

9 files changed

+132
-82
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,8 @@ OBJS_DEEPKS=LCAO_deepks.o\
202202
deepks_vdpre.o\
203203
deepks_hmat.o\
204204
deepks_pdm.o\
205+
deepks_phialpha.o\
205206
LCAO_deepks_io.o\
206-
LCAO_deepks_phialpha.o\
207207
LCAO_deepks_interface.o\
208208

209209

source/module_esolver/lcao_before_scf.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,13 +211,19 @@ void ESolver_KS_LCAO<TK, TR>::before_scf(UnitCell& ucell, const int istep)
211211
{
212212
const Parallel_Orbitals* pv = &this->pv;
213213
// allocate <phi(0)|alpha(R)>, phialpha is different every ion step, so it is allocated here
214-
GlobalC::ld.allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
214+
DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
215215
// build and save <phi(0)|alpha(R)> at beginning
216-
GlobalC::ld.build_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, *(two_center_bundle_.overlap_orb_alpha));
216+
DeePKS_domain::build_phialpha(PARAM.inp.cal_force,
217+
ucell,
218+
orb_,
219+
this->gd,
220+
pv,
221+
*(two_center_bundle_.overlap_orb_alpha),
222+
GlobalC::ld.phialpha);
217223

218224
if (PARAM.inp.deepks_out_unittest)
219225
{
220-
GlobalC::ld.check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
226+
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
221227
}
222228
}
223229
#endif

source/module_esolver/lcao_others.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,19 @@ void ESolver_KS_LCAO<TK, TR>::others(UnitCell& ucell, const int istep)
217217
{
218218
const Parallel_Orbitals* pv = &this->pv;
219219
// allocate <phi(0)|alpha(R)>, phialpha is different every ion step, so it is allocated here
220-
GlobalC::ld.allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
220+
DeePKS_domain::allocate_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
221221
// build and save <phi(0)|alpha(R)> at beginning
222-
GlobalC::ld.build_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, *(two_center_bundle_.overlap_orb_alpha));
222+
DeePKS_domain::build_phialpha(PARAM.inp.cal_force,
223+
ucell,
224+
orb_,
225+
this->gd,
226+
pv,
227+
*(two_center_bundle_.overlap_orb_alpha),
228+
GlobalC::ld.phialpha);
223229

224230
if (PARAM.inp.deepks_out_unittest)
225231
{
226-
GlobalC::ld.check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd);
232+
DeePKS_domain::check_phialpha(PARAM.inp.cal_force, ucell, orb_, this->gd, pv, GlobalC::ld.phialpha);
227233
}
228234
}
229235
#endif

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ if(ENABLE_DEEPKS)
1212
deepks_vdpre.cpp
1313
deepks_hmat.cpp
1414
deepks_pdm.cpp
15+
deepks_phialpha.cpp
1516
LCAO_deepks_io.cpp
16-
LCAO_deepks_phialpha.cpp
1717
LCAO_deepks_interface.cpp
1818
)
1919

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "deepks_orbital.h"
1212
#include "deepks_orbpre.h"
1313
#include "deepks_pdm.h"
14+
#include "deepks_phialpha.h"
1415
#include "deepks_spre.h"
1516
#include "deepks_vdelta.h"
1617
#include "deepks_vdpre.h"
@@ -158,47 +159,14 @@ class LCAO_Deepks
158159
/// Allocate memory for correction to Hamiltonian
159160
void allocate_V_delta(const int nat, const int nks = 1);
160161

161-
private:
162-
// arrange index of descriptor in all atoms
163-
void init_index(const int ntype, const int nat, std::vector<int> na, const int tot_inl, const LCAO_Orbitals& orb);
164-
165-
//-------------------
166-
// LCAO_deepks_phialpha.cpp
167-
//-------------------
168-
169-
// E.Wu 2024-12-24
170-
// This file contains 3 subroutines:
171-
// 1. allocate_phialpha, which allocates memory for phialpha
172-
// 2. build_phialpha, which calculates the overlap
173-
// between atomic basis and projector alpha : <phi_mu|alpha>
174-
// which will be used in calculating pdm, gdmx, H_V_delta, F_delta;
175-
// 3. check_phialpha, which prints the results into .dat files
176-
// for checking
177-
178-
public:
179-
// calculates <chi|alpha>
180-
void allocate_phialpha(const bool& cal_deri,
181-
const UnitCell& ucell,
182-
const LCAO_Orbitals& orb,
183-
const Grid_Driver& GridD);
184-
185-
void build_phialpha(const bool& cal_deri /**< [in] 0 for 2-center intergration, 1 for its derivation*/,
186-
const UnitCell& ucell,
187-
const LCAO_Orbitals& orb,
188-
const Grid_Driver& GridD,
189-
const TwoCenterIntegrator& overlap_orb_alpha);
190-
191-
void check_phialpha(const bool& cal_deri /**< [in] 0 for 2-center intergration, 1 for its derivation*/,
192-
const UnitCell& ucell,
193-
const LCAO_Orbitals& orb,
194-
const Grid_Driver& GridD);
195-
196-
public:
197162
//! a temporary interface for cal_e_delta_band
198163
template <typename TK>
199164
void dpks_cal_e_delta_band(const std::vector<std::vector<TK>>& dm, const int nks);
200165

201166
private:
167+
// arrange index of descriptor in all atoms
168+
void init_index(const int ntype, const int nat, std::vector<int> na, const int tot_inl, const LCAO_Orbitals& orb);
169+
202170
const Parallel_Orbitals* pv;
203171
};
204172

source/module_hamilt_lcao/module_deepks/LCAO_deepks_phialpha.cpp renamed to source/module_hamilt_lcao/module_deepks/deepks_phialpha.cpp

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,24 @@
88

99
#ifdef __DEEPKS
1010

11-
#include "LCAO_deepks.h"
11+
#include "deepks_phialpha.h"
12+
1213
#include "module_base/timer.h"
1314
#include "module_base/vector3.h"
1415
#include "module_parameter/parameter.h"
1516

16-
void LCAO_Deepks::allocate_phialpha(const bool& cal_deri,
17-
const UnitCell& ucell,
18-
const LCAO_Orbitals& orb,
19-
const Grid_Driver& GridD)
17+
void DeePKS_domain::allocate_phialpha(const bool& cal_deri,
18+
const UnitCell& ucell,
19+
const LCAO_Orbitals& orb,
20+
const Grid_Driver& GridD,
21+
const Parallel_Orbitals* pv,
22+
std::vector<hamilt::HContainer<double>*>& phialpha)
2023
{
21-
ModuleBase::TITLE("LCAO_Deepks", "allocate_phialpha");
24+
ModuleBase::TITLE("DeePKS_domain", "allocate_phialpha");
2225

23-
this->phialpha.resize(cal_deri ? 4 : 1);
26+
phialpha.resize(cal_deri ? 4 : 1);
2427

25-
this->phialpha[0] = new hamilt::HContainer<double>(pv); // phialpha is always real
28+
phialpha[0] = new hamilt::HContainer<double>(pv); // phialpha is always real
2629
// Do not use fix_gamma, since it may find wrong matrix for gamma-only case in DeePKS
2730

2831
// cutoff for alpha is same for all types of atoms
@@ -63,31 +66,33 @@ void LCAO_Deepks::allocate_phialpha(const bool& cal_deri,
6366
hamilt::AtomPair<double> pair(iat, ibt, R_index, pv);
6467
// Notice: in AtomPair, the usage is set_size(ncol, nrow)
6568
pair.set_size(nw_alpha, atom1->nw * PARAM.globalv.npol);
66-
this->phialpha[0]->insert_pair(pair);
69+
phialpha[0]->insert_pair(pair);
6770
}
6871
}
6972
}
7073

71-
this->phialpha[0]->allocate(nullptr, true);
74+
phialpha[0]->allocate(nullptr, true);
7275
// whether to calculate the derivative of phialpha
7376
if (cal_deri)
7477
{
7578
for (int i = 1; i < 4; ++i)
7679
{
77-
this->phialpha[i] = new hamilt::HContainer<double>(*this->phialpha[0], nullptr); // copy constructor
80+
phialpha[i] = new hamilt::HContainer<double>(*phialpha[0], nullptr); // copy constructor
7881
}
7982
}
8083
return;
8184
}
8285

83-
void LCAO_Deepks::build_phialpha(const bool& cal_deri,
84-
const UnitCell& ucell,
85-
const LCAO_Orbitals& orb,
86-
const Grid_Driver& GridD,
87-
const TwoCenterIntegrator& overlap_orb_alpha)
86+
void DeePKS_domain::build_phialpha(const bool& cal_deri,
87+
const UnitCell& ucell,
88+
const LCAO_Orbitals& orb,
89+
const Grid_Driver& GridD,
90+
const Parallel_Orbitals* pv,
91+
const TwoCenterIntegrator& overlap_orb_alpha,
92+
std::vector<hamilt::HContainer<double>*>& phialpha)
8893
{
89-
ModuleBase::TITLE("LCAO_Deepks", "build_phialpha");
90-
ModuleBase::timer::tick("LCAO_Deepks", "build_phialpha");
94+
ModuleBase::TITLE("DeePKS_domain", "build_phialpha");
95+
ModuleBase::timer::tick("DeePKS_domain", "build_phialpha");
9196

9297
// cutoff for alpha is same for all types of atoms
9398
const double Rcut_Alpha = orb.Alpha[0].getRcut();
@@ -126,13 +131,13 @@ void LCAO_Deepks::build_phialpha(const bool& cal_deri,
126131
continue;
127132
}
128133

129-
double* data_pointer = this->phialpha[0]->data(iat, ibt, R);
134+
double* data_pointer = phialpha[0]->data(iat, ibt, R);
130135
std::vector<double*> grad_pointer(3);
131136
if (cal_deri)
132137
{
133138
for (int i = 0; i < 3; ++i)
134139
{
135-
grad_pointer[i] = this->phialpha[i + 1]->data(iat, ibt, R);
140+
grad_pointer[i] = phialpha[i + 1]->data(iat, ibt, R);
136141
}
137142
}
138143

@@ -192,17 +197,19 @@ void LCAO_Deepks::build_phialpha(const bool& cal_deri,
192197
}
193198
}
194199

195-
ModuleBase::timer::tick("LCAO_Deepks", "build_phialpha");
200+
ModuleBase::timer::tick("DeePKS_domain", "build_phialpha");
196201
return;
197202
}
198203

199-
void LCAO_Deepks::check_phialpha(const bool& cal_deri,
200-
const UnitCell& ucell,
201-
const LCAO_Orbitals& orb,
202-
const Grid_Driver& GridD)
204+
void DeePKS_domain::check_phialpha(const bool& cal_deri,
205+
const UnitCell& ucell,
206+
const LCAO_Orbitals& orb,
207+
const Grid_Driver& GridD,
208+
const Parallel_Orbitals* pv,
209+
std::vector<hamilt::HContainer<double>*>& phialpha)
203210
{
204-
ModuleBase::TITLE("LCAO_Deepks", "check_phialpha");
205-
ModuleBase::timer::tick("LCAO_Deepks", "check_phialpha");
211+
ModuleBase::TITLE("DeePKS_domain", "check_phialpha");
212+
ModuleBase::timer::tick("DeePKS_domain", "check_phialpha");
206213

207214
const double Rcut_Alpha = orb.Alpha[0].getRcut();
208215
// same for all types of atoms
@@ -280,13 +287,13 @@ void LCAO_Deepks::check_phialpha(const bool& cal_deri,
280287
ofs_z << "R : " << R[0] << " " << R[1] << " " << R[2] << std::endl;
281288
}
282289

283-
const double* data_pointer = this->phialpha[0]->data(iat, ibt, R);
290+
const double* data_pointer = phialpha[0]->data(iat, ibt, R);
284291
std::vector<double*> grad_pointer(3, nullptr);
285292
if (cal_deri)
286293
{
287-
grad_pointer[0] = this->phialpha[1]->data(iat, ibt, R);
288-
grad_pointer[1] = this->phialpha[2]->data(iat, ibt, R);
289-
grad_pointer[2] = this->phialpha[3]->data(iat, ibt, R);
294+
grad_pointer[0] = phialpha[1]->data(iat, ibt, R);
295+
grad_pointer[1] = phialpha[2]->data(iat, ibt, R);
296+
grad_pointer[2] = phialpha[3]->data(iat, ibt, R);
290297
}
291298

292299
for (int iw1 = 0; iw1 < nw1_tot; ++iw1)
@@ -334,7 +341,7 @@ void LCAO_Deepks::check_phialpha(const bool& cal_deri,
334341
} // end I0
335342
} // end T0
336343

337-
ModuleBase::timer::tick("LCAO_Deepks", "check_phialpha");
344+
ModuleBase::timer::tick("DeePKS_domain", "check_phialpha");
338345
return;
339346
}
340347

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#ifndef DEEPKS_PHIALPHA_H
2+
#define DEEPKS_PHIALPHA_H
3+
4+
#ifdef __DEEPKS
5+
6+
#include "module_base/complexmatrix.h"
7+
#include "module_base/matrix.h"
8+
#include "module_base/timer.h"
9+
#include "module_basis/module_ao/parallel_orbitals.h"
10+
#include "module_basis/module_nao/two_center_integrator.h"
11+
#include "module_cell/module_neighbor/sltk_grid_driver.h"
12+
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"
13+
14+
#include <torch/script.h>
15+
#include <torch/torch.h>
16+
17+
namespace DeePKS_domain
18+
{
19+
// This file contains 3 subroutines:
20+
// 1. allocate_phialpha, which allocates memory for phialpha
21+
// 2. build_phialpha, which calculates the overlap
22+
// between atomic basis and projector alpha : <phi_mu|alpha>
23+
// which will be used in calculating pdm, gdmx, H_V_delta, F_delta;
24+
// 3. check_phialpha, which prints the results into .dat files
25+
// for checking
26+
27+
// calculates <chi|alpha>
28+
void allocate_phialpha(const bool& cal_deri,
29+
const UnitCell& ucell,
30+
const LCAO_Orbitals& orb,
31+
const Grid_Driver& GridD,
32+
const Parallel_Orbitals* pv,
33+
std::vector<hamilt::HContainer<double>*>& phialpha);
34+
35+
void build_phialpha(const bool& cal_deri /**< [in] 0 for 2-center intergration, 1 for its derivation*/,
36+
const UnitCell& ucell,
37+
const LCAO_Orbitals& orb,
38+
const Grid_Driver& GridD,
39+
const Parallel_Orbitals* pv,
40+
const TwoCenterIntegrator& overlap_orb_alpha,
41+
std::vector<hamilt::HContainer<double>*>& phialpha);
42+
43+
void check_phialpha(const bool& cal_deri /**< [in] 0 for 2-center intergration, 1 for its derivation*/,
44+
const UnitCell& ucell,
45+
const LCAO_Orbitals& orb,
46+
const Grid_Driver& GridD,
47+
const Parallel_Orbitals* pv,
48+
std::vector<hamilt::HContainer<double>*>& phialpha);
49+
} // namespace DeePKS_domain
50+
51+
#endif
52+
#endif

source/module_hamilt_lcao/module_deepks/test/LCAO_deepks_test.cpp

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,22 @@ void test_deepks::check_phialpha()
3636
}
3737
GlobalC::ld.init(ORB, ucell.nat, ucell.ntype, kv.nkstot, ParaO, na);
3838

39-
GlobalC::ld.allocate_phialpha(PARAM.input.cal_force, ucell, ORB, Test_Deepks::GridD);
40-
41-
GlobalC::ld.build_phialpha(PARAM.input.cal_force, ucell, ORB, Test_Deepks::GridD, overlap_orb_alpha_);
42-
43-
GlobalC::ld.check_phialpha(PARAM.input.cal_force, ucell, ORB, Test_Deepks::GridD);
39+
DeePKS_domain::allocate_phialpha(PARAM.input.cal_force,
40+
ucell,
41+
ORB,
42+
Test_Deepks::GridD,
43+
&ParaO,
44+
GlobalC::ld.phialpha);
45+
46+
DeePKS_domain::build_phialpha(PARAM.input.cal_force,
47+
ucell,
48+
ORB,
49+
Test_Deepks::GridD,
50+
&ParaO,
51+
overlap_orb_alpha_,
52+
GlobalC::ld.phialpha);
53+
54+
DeePKS_domain::check_phialpha(PARAM.input.cal_force, ucell, ORB, Test_Deepks::GridD, &ParaO, GlobalC::ld.phialpha);
4455

4556
this->compare_with_ref("phialpha.dat", "phialpha_ref.dat");
4657
this->compare_with_ref("dphialpha_x.dat", "dphialpha_x_ref.dat");

source/module_hamilt_lcao/module_deepks/test/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ deepks_basic.o\
1111
deepks_force.o\
1212
deepks_vdelta.o\
1313
deepks_pdm.o\
14+
deepks_phialpha.o\
1415
LCAO_deepks.o\
1516
LCAO_deepks_io.o\
16-
LCAO_deepks_phialpha.o\
1717

1818
OBJS_IO=output.o\
1919

0 commit comments

Comments
 (0)