Skip to content

Commit f3a5dab

Browse files
committed
Simplify vdpre and vdrpre.
1 parent 77c253d commit f3a5dab

File tree

6 files changed

+42
-80
lines changed

6 files changed

+42
-80
lines changed

source/source_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -479,16 +479,13 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
479479
int R_size = DeePKS_domain::get_R_size(*h_deltaR);
480480
torch::Tensor vdr_precalc;
481481
DeePKS_domain::cal_vdr_precalc(nlocal,
482-
lmaxd,
483-
inlmax,
484482
nat,
485483
nks,
486484
R_size,
487-
inl2l,
485+
deepks_param,
488486
kvec_d,
489487
phialpha,
490488
gevdm,
491-
inl_index,
492489
ucell,
493490
orb,
494491
*ParaV,
@@ -503,10 +500,9 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
503500
int R_size = DeePKS_domain::get_R_size(*h_deltaR);
504501
torch::Tensor phialpha_r_out;
505502
DeePKS_domain::prepare_phialpha_r(nlocal,
506-
lmaxd,
507-
inlmax,
508503
nat,
509504
R_size,
505+
deepks_param,
510506
phialpha,
511507
ucell,
512508
orb,
@@ -517,7 +513,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
517513
LCAO_deepks_io::save_tensor2npy<double>(file_phialpha_r, phialpha_r_out, rank);
518514

519515
torch::Tensor gevdm_out;
520-
DeePKS_domain::prepare_gevdm(nat, lmaxd, inlmax, orb, gevdm, gevdm_out);
516+
DeePKS_domain::prepare_gevdm(nat, deepks_param, orb, gevdm, gevdm_out);
521517
const std::string file_gevdm = PARAM.globalv.global_out_dir + "deepks_gevdm.npy";
522518
LCAO_deepks_io::save_tensor2npy<double>(file_gevdm, gevdm_out, rank);
523519
}
@@ -575,15 +571,12 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
575571
{
576572
torch::Tensor v_delta_precalc;
577573
DeePKS_domain::cal_v_delta_precalc<TK>(nlocal,
578-
lmaxd,
579-
inlmax,
580574
nat,
581575
nks,
582-
inl2l,
576+
deepks_param,
583577
kvec_d,
584578
phialpha,
585579
gevdm,
586-
inl_index,
587580
ucell,
588581
orb,
589582
*ParaV,
@@ -597,10 +590,9 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
597590
{
598591
torch::Tensor phialpha_out;
599592
DeePKS_domain::prepare_phialpha<TK>(nlocal,
600-
lmaxd,
601-
inlmax,
602593
nat,
603594
nks,
595+
deepks_param,
604596
kvec_d,
605597
phialpha,
606598
ucell,
@@ -612,7 +604,7 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
612604
LCAO_deepks_io::save_tensor2npy<TK>(file_phialpha, phialpha_out, rank);
613605

614606
torch::Tensor gevdm_out;
615-
DeePKS_domain::prepare_gevdm(nat, lmaxd, inlmax, orb, gevdm, gevdm_out);
607+
DeePKS_domain::prepare_gevdm(nat, deepks_param, orb, gevdm, gevdm_out);
616608
const std::string file_gevdm = get_filename("gevdm", PARAM.inp.deepks_out_labels, iter);
617609
LCAO_deepks_io::save_tensor2npy<double>(file_gevdm, gevdm_out, rank);
618610
}

source/source_lcao/module_deepks/deepks_vdpre.cpp

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,12 @@
2222
// for deepks_v_delta = 1
2323
template <typename TK>
2424
void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
25-
const int lmaxd,
26-
const int inlmax,
2725
const int nat,
2826
const int nks,
29-
const std::vector<int>& inl2l,
27+
const DeePKS_Param& deepks_param,
3028
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
3129
const std::vector<hamilt::HContainer<double>*> phialpha,
3230
const std::vector<torch::Tensor> gevdm,
33-
const ModuleBase::IntArray* inl_index,
3431
const UnitCell& ucell,
3532
const LCAO_Orbitals& orb,
3633
const Parallel_Orbitals& pv,
@@ -47,7 +44,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
4744
typename std::conditional<std::is_same<TK, std::complex<double>>::value, c10::complex<double>, TK>::type;
4845

4946
torch::Tensor v_delta_pdm
50-
= torch::zeros({nks, nlocal, nlocal, inlmax, (2 * lmaxd + 1), (2 * lmaxd + 1)}, torch::dtype(dtype));
47+
= torch::zeros({nks, nlocal, nlocal, deepks_param.inlmax, (2 * deepks_param.lmaxd + 1), (2 * deepks_param.lmaxd + 1)}, torch::dtype(dtype));
5148
auto accessor = v_delta_pdm.accessor<TK_tensor, 6>();
5249

5350
DeePKS_domain::iterate_ad2(
@@ -112,7 +109,7 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
112109
{
113110
for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0)
114111
{
115-
const int inl = inl_index[T0](I0, L0, N0);
112+
const int inl = deepks_param.inl_index[T0](I0, L0, N0);
116113
const int nm = 2 * L0 + 1;
117114
for (int m1 = 0; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d
118115
{
@@ -133,20 +130,20 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
133130
}
134131
);
135132
#ifdef __MPI
136-
const int size = nks * nlocal * nlocal * inlmax * (2 * lmaxd + 1) * (2 * lmaxd + 1);
133+
const int size = nks * nlocal * nlocal * deepks_param.inlmax * (2 * deepks_param.lmaxd + 1) * (2 * deepks_param.lmaxd + 1);
137134
TK_tensor* data_tensor_ptr = v_delta_pdm.data_ptr<TK_tensor>();
138135
TK* data_ptr = reinterpret_cast<TK*>(data_tensor_ptr);
139136
Parallel_Reduce::reduce_all(data_ptr, size);
140137
#endif
141138

142139
// transfer v_delta_pdm to v_delta_pdm_vector
143-
int nlmax = inlmax / nat;
140+
int nlmax = deepks_param.inlmax / nat;
144141
std::vector<torch::Tensor> v_delta_pdm_vector;
145142
for (int nl = 0; nl < nlmax; ++nl)
146143
{
147-
int nm = 2 * inl2l[nl] + 1;
144+
int nm = 2 * deepks_param.inl2l[nl] + 1;
148145
torch::Tensor v_delta_pdm_sliced
149-
= v_delta_pdm.slice(3, nl, inlmax, nlmax).slice(4, 0, nm, 1).slice(5, 0, nm, 1);
146+
= v_delta_pdm.slice(3, nl, deepks_param.inlmax, nlmax).slice(4, 0, nm, 1).slice(5, 0, nm, 1);
150147
v_delta_pdm_vector.push_back(v_delta_pdm_sliced);
151148
}
152149

@@ -173,10 +170,9 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
173170
// prepare_phialpha and prepare_gevdm for deepks_v_delta = 2
174171
template <typename TK>
175172
void DeePKS_domain::prepare_phialpha(const int nlocal,
176-
const int lmaxd,
177-
const int inlmax,
178173
const int nat,
179174
const int nks,
175+
const DeePKS_Param& deepks_param,
180176
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
181177
const std::vector<hamilt::HContainer<double>*> phialpha,
182178
const UnitCell& ucell,
@@ -190,8 +186,8 @@ void DeePKS_domain::prepare_phialpha(const int nlocal,
190186
constexpr torch::Dtype dtype = std::is_same<TK, double>::value ? torch::kFloat64 : torch::kComplexDouble;
191187
using TK_tensor =
192188
typename std::conditional<std::is_same<TK, std::complex<double>>::value, c10::complex<double>, TK>::type;
193-
int nlmax = inlmax / nat;
194-
int mmax = 2 * lmaxd + 1;
189+
int nlmax = deepks_param.inlmax / nat;
190+
int mmax = 2 * deepks_param.lmaxd + 1;
195191
phialpha_out = torch::zeros({nat, nlmax, nks, nlocal, mmax}, dtype);
196192
auto accessor = phialpha_out.accessor<TK_tensor, 5>();
197193

@@ -268,16 +264,15 @@ void DeePKS_domain::prepare_phialpha(const int nlocal,
268264
}
269265

270266
void DeePKS_domain::prepare_gevdm(const int nat,
271-
const int lmaxd,
272-
const int inlmax,
267+
const DeePKS_Param& deepks_param,
273268
const LCAO_Orbitals& orb,
274269
const std::vector<torch::Tensor>& gevdm_in,
275270
torch::Tensor& gevdm_out)
276271
{
277272
ModuleBase::TITLE("DeePKS_domain", "prepare_gevdm");
278273
ModuleBase::timer::tick("DeePKS_domain", "prepare_gevdm");
279-
int nlmax = inlmax / nat;
280-
int mmax = 2 * lmaxd + 1;
274+
int nlmax = deepks_param.inlmax / nat;
275+
int mmax = 2 * deepks_param.lmaxd + 1;
281276
gevdm_out = torch::zeros({nat, nlmax, mmax, mmax, mmax}, torch::TensorOptions().dtype(torch::kFloat64));
282277

283278
std::vector<torch::Tensor> gevdm_out_vector;
@@ -295,42 +290,35 @@ void DeePKS_domain::prepare_gevdm(const int nat,
295290
}
296291

297292
template void DeePKS_domain::cal_v_delta_precalc<double>(const int nlocal,
298-
const int lmaxd,
299-
const int inlmax,
300293
const int nat,
301294
const int nks,
302-
const std::vector<int>& inl2l,
295+
const DeePKS_Param& deepks_param,
303296
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
304297
const std::vector<hamilt::HContainer<double>*> phialpha,
305298
const std::vector<torch::Tensor> gevdm,
306-
const ModuleBase::IntArray* inl_index,
307299
const UnitCell& ucell,
308300
const LCAO_Orbitals& orb,
309301
const Parallel_Orbitals& pv,
310302
const Grid_Driver& GridD,
311303
torch::Tensor& v_delta_precalc);
312304
template void DeePKS_domain::cal_v_delta_precalc<std::complex<double>>(
313305
const int nlocal,
314-
const int lmaxd,
315-
const int inlmax,
316306
const int nat,
317307
const int nks,
318-
const std::vector<int>& inl2l,
308+
const DeePKS_Param& deepks_param,
319309
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
320310
const std::vector<hamilt::HContainer<double>*> phialpha,
321311
const std::vector<torch::Tensor> gevdm,
322-
const ModuleBase::IntArray* inl_index,
323312
const UnitCell& ucell,
324313
const LCAO_Orbitals& orb,
325314
const Parallel_Orbitals& pv,
326315
const Grid_Driver& GridD,
327316
torch::Tensor& v_delta_precalc);
328317

329318
template void DeePKS_domain::prepare_phialpha<double>(const int nlocal,
330-
const int lmaxd,
331-
const int inlmax,
332319
const int nat,
333320
const int nks,
321+
const DeePKS_Param& deepks_param,
334322
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
335323
const std::vector<hamilt::HContainer<double>*> phialpha,
336324
const UnitCell& ucell,
@@ -341,10 +329,9 @@ template void DeePKS_domain::prepare_phialpha<double>(const int nlocal,
341329

342330
template void DeePKS_domain::prepare_phialpha<std::complex<double>>(
343331
const int nlocal,
344-
const int lmaxd,
345-
const int inlmax,
346332
const int nat,
347333
const int nks,
334+
const DeePKS_Param& deepks_param,
348335
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
349336
const std::vector<hamilt::HContainer<double>*> phialpha,
350337
const UnitCell& ucell,

source/source_lcao/module_deepks/deepks_vdpre.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#ifdef __MLALGO
55

6+
#include "deepks_param.h"
67
#include "source_base/complexmatrix.h"
78
#include "source_base/intarray.h"
89
#include "source_base/matrix.h"
@@ -32,15 +33,12 @@ namespace DeePKS_domain
3233
// calculates v_delta_precalc
3334
template <typename TK>
3435
void cal_v_delta_precalc(const int nlocal,
35-
const int lmaxd,
36-
const int inlmax,
3736
const int nat,
3837
const int nks,
39-
const std::vector<int>& inl2l,
38+
const DeePKS_Param& deepks_param,
4039
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
4140
const std::vector<hamilt::HContainer<double>*> phialpha,
4241
const std::vector<torch::Tensor> gevdm,
43-
const ModuleBase::IntArray* inl_index,
4442
const UnitCell& ucell,
4543
const LCAO_Orbitals& orb,
4644
const Parallel_Orbitals& pv,
@@ -51,10 +49,9 @@ void cal_v_delta_precalc(const int nlocal,
5149
// prepare phialpha for outputting npy file
5250
template <typename TK>
5351
void prepare_phialpha(const int nlocal,
54-
const int lmaxd,
55-
const int inlmax,
5652
const int nat,
5753
const int nks,
54+
const DeePKS_Param& deepks_param,
5855
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
5956
const std::vector<hamilt::HContainer<double>*> phialpha,
6057
const UnitCell& ucell,
@@ -65,8 +62,7 @@ void prepare_phialpha(const int nlocal,
6562

6663
// prepare gevdm for outputting npy file
6764
void prepare_gevdm(const int nat,
68-
const int lmaxd,
69-
const int inlmax,
65+
const DeePKS_Param& deepks_param,
7066
const LCAO_Orbitals& orb,
7167
const std::vector<torch::Tensor>& gevdm_in,
7268
torch::Tensor& gevdm_out);

0 commit comments

Comments
 (0)