Skip to content

Commit 327707b

Browse files
committed
Simplify pdm.
1 parent 27bd74d commit 327707b

File tree

6 files changed

+166
-203
lines changed

6 files changed

+166
-203
lines changed

source/source_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
124124
// this part is for integrated test of deepks
125125
// so it is printed no matter even if deepks_out_labels is not used
126126
DeePKS_domain::cal_pdm<
127-
TK>(init_pdm, inlmax, lmaxd, inl2l, inl_index, kvec_d, dmr, phialpha, ucell, orb, GridD, *ParaV, pdm);
127+
TK>(init_pdm, deepks_param, kvec_d, dmr, phialpha, ucell, orb, GridD, *ParaV, pdm);
128128

129-
DeePKS_domain::check_pdm(inlmax, inl2l, pdm); // print out the projected dm for NSCF calculaiton
129+
DeePKS_domain::check_pdm(deepks_param, pdm); // print out the projected dm for NSCF calculaiton
130130

131131
std::vector<torch::Tensor> descriptor;
132132
DeePKS_domain::cal_descriptor(nat, deepks_param, pdm, descriptor); // final descriptor
@@ -368,15 +368,12 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
368368
torch::Tensor orbital_precalc_temp;
369369
ModuleBase::matrix o_delta_temp(nks, 1);
370370
DeePKS_domain::cal_orbital_precalc<TK, TH>(dm_bandgap,
371-
lmaxd,
372-
inlmax,
373371
nat,
374372
nks,
375-
inl2l,
373+
deepks_param,
376374
kvec_d,
377375
phialpha,
378376
gevdm,
379-
inl_index,
380377
ucell,
381378
orb,
382379
*ParaV,

source/source_lcao/module_deepks/deepks_pdm.cpp

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ void DeePKS_domain::read_pdm(bool read_pdm_file,
2929
bool is_equiv,
3030
bool& init_pdm,
3131
const int nat,
32-
const int inlmax,
33-
const int lmaxd,
34-
const std::vector<int>& inl2l,
32+
const DeePKS_Param& deepks_param,
3533
const Numerical_Orbital& alpha,
3634
std::vector<torch::Tensor>& pdm)
3735
{
@@ -46,9 +44,9 @@ void DeePKS_domain::read_pdm(bool read_pdm_file,
4644
}
4745
if (!is_equiv)
4846
{
49-
for (int inl = 0; inl < inlmax; inl++)
47+
for (int inl = 0; inl < deepks_param.inlmax; inl++)
5048
{
51-
int nm = 2 * inl2l[inl] + 1;
49+
int nm = 2 * deepks_param.inl2l[inl] + 1;
5250
auto accessor = pdm[inl].accessor<double, 2>();
5351
for (int m1 = 0; m1 < nm; m1++)
5452
{
@@ -65,7 +63,7 @@ void DeePKS_domain::read_pdm(bool read_pdm_file,
6563
{
6664
int pdm_size = 0;
6765
int nproj = 0;
68-
for (int il = 0; il < lmaxd + 1; il++)
66+
for (int il = 0; il < deepks_param.lmaxd + 1; il++)
6967
{
7068
nproj += (2 * il + 1) * alpha.getNchi(il);
7169
}
@@ -177,10 +175,7 @@ void DeePKS_domain::update_dmr(const std::vector<ModuleBase::Vector3<double>>& k
177175
// pdm_m,m'=\sum_{mu,nu} rho_{mu,nu} <chi_mu|alpha_m><alpha_m'|chi_nu>
178176
template <typename TK>
179177
void DeePKS_domain::cal_pdm(bool& init_pdm,
180-
const int inlmax,
181-
const int lmaxd,
182-
const std::vector<int>& inl2l,
183-
const ModuleBase::IntArray* inl_index,
178+
const DeePKS_Param& deepks_param,
184179
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
185180
const hamilt::HContainer<double>* dmr,
186181
const std::vector<hamilt::HContainer<double>*> phialpha,
@@ -203,17 +198,17 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
203198

204199
if (!PARAM.inp.deepks_equiv)
205200
{
206-
for (int inl = 0; inl < inlmax; inl++)
201+
for (int inl = 0; inl < deepks_param.inlmax; inl++)
207202
{
208-
int nm = 2 * inl2l[inl] + 1;
203+
int nm = 2 * deepks_param.inl2l[inl] + 1;
209204
pdm[inl] = torch::zeros({nm, nm}, torch::kFloat64);
210205
}
211206
}
212207
else
213208
{
214209
int pdm_size = 0;
215210
int nproj = 0;
216-
for (int il = 0; il < lmaxd + 1; il++)
211+
for (int il = 0; il < deepks_param.lmaxd + 1; il++)
217212
{
218213
nproj += (2 * il + 1) * orb.Alpha[0].getNchi(il);
219214
}
@@ -246,7 +241,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
246241
{
247242
for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0)
248243
{
249-
const int inl = inl_index[T0](I0, L0, N0);
244+
const int inl = deepks_param.inl_index[T0](I0, L0, N0);
250245
const int nm = 2 * L0 + 1;
251246

252247
for (int m1 = 0; m1 < nm; ++m1) // m1 = 1 for s, 3 for p, 5 for d
@@ -264,7 +259,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
264259
else
265260
{
266261
int nproj = 0;
267-
for (int il = 0; il < lmaxd + 1; il++)
262+
for (int il = 0; il < deepks_param.lmaxd + 1; il++)
268263
{
269264
nproj += (2 * il + 1) * orb.Alpha[0].getNchi(il);
270265
}
@@ -399,7 +394,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
399394
{
400395
for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0)
401396
{
402-
const int inl = inl_index[T0](I0, L0, N0);
397+
const int inl = deepks_param.inl_index[T0](I0, L0, N0);
403398
const int nm = 2 * L0 + 1;
404399

405400
auto accessor = pdm[inl].accessor<double, 2>();
@@ -423,7 +418,7 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
423418
auto accessor = pdm[iat].accessor<double, 1>();
424419
int index = 0, inc = 1;
425420
int nproj = 0;
426-
for (int il = 0; il < lmaxd + 1; il++)
421+
for (int il = 0; il < deepks_param.lmaxd + 1; il++)
427422
{
428423
nproj += (2 * il + 1) * orb.Alpha[0].getNchi(il);
429424
}
@@ -446,25 +441,25 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
446441
} // iat
447442

448443
#ifdef __MPI
449-
for (int inl = 0; inl < inlmax; inl++)
444+
for (int inl = 0; inl < deepks_param.inlmax; inl++)
450445
{
451-
int pdm_size = (2 * inl2l[inl] + 1) * (2 * inl2l[inl] + 1);
446+
int pdm_size = (2 * deepks_param.inl2l[inl] + 1) * (2 * deepks_param.inl2l[inl] + 1);
452447
Parallel_Reduce::reduce_all(pdm[inl].data_ptr<double>(), pdm_size);
453448
}
454449
#endif
455450
ModuleBase::timer::tick("DeePKS_domain", "cal_pdm");
456451
return;
457452
}
458453

459-
void DeePKS_domain::check_pdm(const int inlmax, const std::vector<int>& inl2l, const std::vector<torch::Tensor>& pdm)
454+
void DeePKS_domain::check_pdm(const DeePKS_Param& deepks_param, const std::vector<torch::Tensor>& pdm)
460455
{
461456
const std::string file_projdm = PARAM.globalv.global_out_dir + "deepks_projdm.dat";
462457
std::ofstream ofs(file_projdm.c_str());
463458

464459
ofs << std::setprecision(10);
465-
for (int inl = 0; inl < inlmax; inl++)
460+
for (int inl = 0; inl < deepks_param.inlmax; inl++)
466461
{
467-
const int nm = 2 * inl2l[inl] + 1;
462+
const int nm = 2 * deepks_param.inl2l[inl] + 1;
468463
auto accessor = pdm[inl].accessor<double, 2>();
469464
for (int m1 = 0; m1 < nm; m1++)
470465
{
@@ -494,10 +489,7 @@ template void DeePKS_domain::update_dmr<std::complex<double>>(const std::vector<
494489
hamilt::HContainer<double>* dmr_deepks);
495490

496491
template void DeePKS_domain::cal_pdm<double>(bool& init_pdm,
497-
const int inlmax,
498-
const int lmaxd,
499-
const std::vector<int>& inl2l,
500-
const ModuleBase::IntArray* inl_index,
492+
const DeePKS_Param& deepks_param,
501493
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
502494
const hamilt::HContainer<double>* dmr,
503495
const std::vector<hamilt::HContainer<double>*> phialpha,
@@ -508,10 +500,7 @@ template void DeePKS_domain::cal_pdm<double>(bool& init_pdm,
508500
std::vector<torch::Tensor>& pdm);
509501

510502
template void DeePKS_domain::cal_pdm<std::complex<double>>(bool& init_pdm,
511-
const int inlmax,
512-
const int lmaxd,
513-
const std::vector<int>& inl2l,
514-
const ModuleBase::IntArray* inl_index,
503+
const DeePKS_Param& deepks_param,
515504
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
516505
const hamilt::HContainer<double>* dmr,
517506
const std::vector<hamilt::HContainer<double>*> phialpha,

source/source_lcao/module_deepks/deepks_pdm.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#ifdef __MLALGO
55

6+
#include "deepks_param.h"
67
#include "source_base/complexmatrix.h"
78
#include "source_base/matrix.h"
89
#include "source_base/timer.h"
@@ -33,9 +34,7 @@ void read_pdm(bool read_pdm_file,
3334
bool is_equiv,
3435
bool& init_pdm,
3536
const int nat,
36-
const int inlmax,
37-
const int lmaxd,
38-
const std::vector<int>& inl2l,
37+
const DeePKS_Param& deepks_param,
3938
const Numerical_Orbital& alpha,
4039
std::vector<torch::Tensor>& pdm);
4140

@@ -55,10 +54,7 @@ void update_dmr(const std::vector<ModuleBase::Vector3<double>>& kvec_d,
5554
// - Relax/Cell-Relax/MD calculation, non-first step will use the convergence pdm from the last step as initial pdm
5655
template <typename TK>
5756
void cal_pdm(bool& init_pdm,
58-
const int inlmax,
59-
const int lmaxd,
60-
const std::vector<int>& inl2l,
61-
const ModuleBase::IntArray* inl_index,
57+
const DeePKS_Param& deepks_param,
6258
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
6359
const hamilt::HContainer<double>* dmr,
6460
const std::vector<hamilt::HContainer<double>*> phialpha,
@@ -68,7 +64,7 @@ void cal_pdm(bool& init_pdm,
6864
const Parallel_Orbitals& pv,
6965
std::vector<torch::Tensor>& pdm);
7066

71-
void check_pdm(const int inlmax, const std::vector<int>& inl2l, const std::vector<torch::Tensor>& pdm);
67+
void check_pdm(const DeePKS_Param& deepks_param, const std::vector<torch::Tensor>& pdm);
7268
} // namespace DeePKS_domain
7369

7470
#endif

source/source_lcao/module_deepks/test/LCAO_deepks_test.cpp

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,7 @@ void test_deepks<T>::check_pdm()
146146
Test_Deepks::GridD,
147147
this->ld.dm_r);
148148
DeePKS_domain::cal_pdm<T>(this->ld.init_pdm,
149-
this->ld.inlmax,
150-
this->ld.lmaxd,
151-
this->ld.inl2l,
152-
this->ld.inl_index,
149+
this->ld.deepks_param,
153150
kv.kvec_d,
154151
this->ld.dm_r,
155152
this->ld.phialpha,
@@ -158,7 +155,7 @@ void test_deepks<T>::check_pdm()
158155
Test_Deepks::GridD,
159156
ParaO,
160157
this->ld.pdm);
161-
DeePKS_domain::check_pdm(this->ld.inlmax, this->ld.inl2l, this->ld.pdm);
158+
DeePKS_domain::check_pdm(this->ld.deepks_param, this->ld.pdm);
162159
this->compare_with_ref("deepks_projdm.dat", "pdm_ref.dat");
163160
}
164161

@@ -246,15 +243,12 @@ void test_deepks<T>::check_orbpre()
246243
torch::Tensor orbpre;
247244
DeePKS_domain::cal_gevdm(ucell.nat, this->ld.deepks_param, this->ld.pdm, gevdm);
248245
DeePKS_domain::cal_orbital_precalc<T, TH>(dm,
249-
this->ld.lmaxd,
250-
this->ld.inlmax,
251246
ucell.nat,
252247
kv.nkstot,
253-
this->ld.inl2l,
248+
this->ld.deepks_param,
254249
kv.kvec_d,
255250
this->ld.phialpha,
256251
gevdm,
257-
this->ld.inl_index,
258252
ucell,
259253
ORB,
260254
ParaO,

source/source_lcao/module_operator_lcao/deepks_lcao.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,8 @@ void hamilt::DeePKS<hamilt::OperatorLCAO<TK, TR>>::contributeHR()
152152
{
153153
ModuleBase::timer::tick("DeePKS", "contributeHR");
154154

155-
const int inlmax = ptr_orb_->Alpha[0].getTotal_nchi() * this->ucell->nat;
156-
157155
DeePKS_domain::cal_pdm<TK>(this->ld->init_pdm,
158-
inlmax,
159-
this->ld->deepks_param.lmaxd,
160-
this->ld->deepks_param.inl2l,
161-
this->ld->deepks_param.inl_index,
156+
this->ld->deepks_param,
162157
this->kvec_d,
163158
this->ld->dm_r,
164159
this->ld->phialpha,

0 commit comments

Comments
 (0)