Skip to content

Commit 922b776

Browse files
committed
Change the output of HR precalc.
1 parent ced3357 commit 922b776

File tree

3 files changed

+103
-100
lines changed

3 files changed

+103
-100
lines changed

source/source_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -475,25 +475,19 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
475475
}
476476
else if (PARAM.inp.deepks_v_delta == -2)
477477
{
478-
int R_size = DeePKS_domain::get_R_size(*h_deltaR);
479-
torch::Tensor phialpha_r_out;
480-
DeePKS_domain::prepare_phialpha_r(nlocal,
481-
nat,
482-
R_size,
483-
deepks_param,
484-
phialpha,
485-
ucell,
486-
orb,
487-
*ParaV,
488-
GridD,
489-
phialpha_r_out);
490-
const std::string file_phialpha_r = PARAM.globalv.global_out_dir + "deepks_phialpha_r.npy";
491-
LCAO_deepks_io::save_tensor2npy<double>(file_phialpha_r, phialpha_r_out, rank);
492-
493478
torch::Tensor gevdm_out;
494479
DeePKS_domain::prepare_gevdm(nat, deepks_param, orb, gevdm, gevdm_out);
495480
const std::string file_gevdm = PARAM.globalv.global_out_dir + "deepks_gevdm.npy";
496481
LCAO_deepks_io::save_tensor2npy<double>(file_gevdm, gevdm_out, rank);
482+
483+
int R_size = DeePKS_domain::get_R_size(*h_deltaR);
484+
torch::Tensor overlap_out;
485+
torch::Tensor iRmat;
486+
DeePKS_domain::prepare_phialpha_iRmat(nlocal, R_size, deepks_param, phialpha, ucell, orb, GridD, overlap_out, iRmat);
487+
const std::string file_overlap = PARAM.globalv.global_out_dir + "deepks_phialpha_r.npy";
488+
LCAO_deepks_io::save_tensor2npy<double>(file_overlap, overlap_out, rank);
489+
const std::string file_iRmat = PARAM.globalv.global_out_dir + "deepks_iRmat.npy";
490+
LCAO_deepks_io::save_tensor2npy<int>(file_iRmat, iRmat, rank);
497491
}
498492
}
499493
}

source/source_lcao/module_deepks/deepks_vdrpre.cpp

Lines changed: 83 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// prepare_phialpha_r : prepare phialpha_r for outputting npy file
1+
// prepare_phialpha_iRmat : prepare the phialpha overlap tensor (written to deepks_phialpha_r.npy) and iRmat for outputting npy files
22

33
#ifdef __MLALGO
44

@@ -13,85 +13,84 @@
1313
#include "source_io/module_parameter/parameter.h"
1414
#include "source_lcao/module_hcontainer/atom_pair.h"
1515

// NOTE(review): this span is a rendered diff hunk. Lines whose gutter number ends in '-' belong to
// the removed prepare_phialpha_r; lines whose gutter number ends in '+' belong to the new
// prepare_phialpha_iRmat; gutters with two numbers are unchanged context shared by both.
16-
void DeePKS_domain::prepare_phialpha_r(const int nlocal,
17-
const int nat,
18-
const int R_size,
19-
const DeePKS_Param& deepks_param,
20-
const std::vector<hamilt::HContainer<double>*> phialpha,
21-
const UnitCell& ucell,
22-
const LCAO_Orbitals& orb,
23-
const Parallel_Orbitals& pv,
24-
const Grid_Driver& GridD,
25-
torch::Tensor& phialpha_r_out)
// prepare_phialpha_iRmat: fills two output tensors from phialpha[0].
//   overlap : zero-initialised float64 tensor of shape
//             {ucell.nat, nnmax, nlocal, deepks_param.des_per_atom}; entry [iat][k][start + ix][iy]
//             receives get_value(ix, iy) of the k-th matrix found for iat.
//   iRmat   : mapping_R applied to the pairwise differences of the dR vectors recorded per iat.
// nnmax is the largest per-iat number of (ibt, dR) visits made by iterate_ad1 for which
// phialpha[0]->find_matrix returns a matrix; slots beyond an atom's own count keep their zeros.
// NOTE(review): R_size is accepted but never read in this body.
// NOTE(review): if ucell.nat == 0, nnmax_vec is empty and std::max_element returns end(), which is
// dereferenced below - confirm callers guarantee at least one atom.
// NOTE(review): unlike the removed version, no Parallel_Reduce::reduce_all is performed here -
// confirm every rank that writes the file sees the complete phialpha[0].
16+
void DeePKS_domain::prepare_phialpha_iRmat(const int nlocal,
17+
const int R_size,
18+
const DeePKS_Param& deepks_param,
19+
const std::vector<hamilt::HContainer<double>*> phialpha,
20+
const UnitCell& ucell,
21+
const LCAO_Orbitals& orb,
22+
const Grid_Driver& GridD,
23+
torch::Tensor& overlap,
24+
torch::Tensor& iRmat)
2625
{
27-
ModuleBase::TITLE("DeePKS_domain", "prepare_phialpha_r");
28-
ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_r");
26+
ModuleBase::TITLE("DeePKS_domain", "prepare_phialpha_iRmat");
27+
ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_iRmat");
2928
constexpr torch::Dtype dtype = torch::kFloat64;
30-
int nlmax = deepks_param.inlmax / nat;
31-
int mmax = 2 * deepks_param.lmaxd + 1;
32-
33-
phialpha_r_out = torch::zeros({R_size, R_size, R_size, nat, nlmax, nlocal, mmax}, dtype);
34-
auto accessor = phialpha_r_out.accessor<double, 7>();
35-
36-
DeePKS_domain::iterate_ad1(ucell,
37-
GridD,
38-
orb,
39-
false, // no trace_alpha
40-
[&](const int iat,
41-
const ModuleBase::Vector3<double>& tau0,
42-
const int ibt,
43-
const ModuleBase::Vector3<double>& tau,
44-
const int start,
45-
const int nw_tot,
46-
ModuleBase::Vector3<int> dR) {
47-
if (phialpha[0]->find_matrix(iat, ibt, dR.x, dR.y, dR.z) == nullptr)
48-
{
49-
return; // to next loop
50-
}
5129

52-
// middle loop : all atomic basis on the adjacent atom ad
53-
for (int iw1 = 0; iw1 < nw_tot; ++iw1)
54-
{
55-
const int iw1_all = start + iw1;
56-
const int iw1_local = pv.global2local_row(iw1_all);
57-
const int iw2_local = pv.global2local_col(iw1_all);
58-
if (iw1_local < 0 || iw2_local < 0)
59-
{
60-
continue;
61-
}
62-
hamilt::BaseMatrix<double>* overlap = phialpha[0]->find_matrix(iat, ibt, dR);
63-
const int iR = phialpha[0]->find_R(dR);
64-
65-
int ib = 0;
66-
int nl = 0;
67-
for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0)
68-
{
69-
for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0)
70-
{
71-
const int nm = 2 * L0 + 1;
72-
for (int m1 = 0; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d
73-
{
74-
int iRx = DeePKS_domain::mapping_R(dR.x);
75-
int iRy = DeePKS_domain::mapping_R(dR.y);
76-
int iRz = DeePKS_domain::mapping_R(dR.z);
77-
accessor[iRx][iRy][iRz][iat][nl][iw1_all][m1]
78-
+= overlap->get_value(iw1, ib + m1);
79-
}
80-
ib += nm;
81-
nl++;
82-
}
83-
}
84-
} // end iw
85-
});
86-
87-
#ifdef __MPI
88-
int size = R_size * R_size * R_size * nat * nlmax * nlocal * mmax;
89-
double* data_ptr = phialpha_r_out.data_ptr<double>();
90-
Parallel_Reduce::reduce_all(data_ptr, size);
// First pass (new code): only counts, per iat, how many visits have a stored matrix, so that the
// output tensors can be allocated with a common second dimension nnmax.
30+
// get the maximum nnmax
31+
std::vector<int> nnmax_vec(ucell.nat, 0);
32+
DeePKS_domain::iterate_ad1(
33+
ucell,
34+
GridD,
35+
orb,
36+
false, // no trace_alpha
37+
[&](const int iat,
38+
const ModuleBase::Vector3<double>& tau0,
39+
const int ibt,
40+
const ModuleBase::Vector3<double>& tau1,
41+
const int start,
42+
const int nw_tot,
43+
ModuleBase::Vector3<int> dR)
44+
{
45+
if (phialpha[0]->find_matrix(iat, ibt, dR.x, dR.y, dR.z) == nullptr)
46+
{
47+
return; // to next loop
48+
}
49+
nnmax_vec[iat]++;
50+
}
51+
);
52+
53+
int nnmax = *std::max_element(nnmax_vec.begin(), nnmax_vec.end());
54+
overlap = torch::zeros({ucell.nat, nnmax, nlocal, deepks_param.des_per_atom}, dtype);
55+
torch::Tensor dRmat_tmp = torch::zeros({ucell.nat, nnmax, 3}, torch::kInt32);
56+
auto overlap_accessor = overlap.accessor<double, 4>();
57+
auto dRmat_accessor = dRmat_tmp.accessor<int, 3>();
9158

92-
#endif
// Second pass (new code): nnmax_vec is reset to zero and reused as the per-iat write cursor into
// overlap and dRmat_tmp; it is advanced once per visit that has a stored matrix.
// NOTE(review): this pass uses the find_matrix(iat, ibt, dR) overload while the counting pass uses
// find_matrix(iat, ibt, dR.x, dR.y, dR.z); the allocation is only sufficient if both agree.
59+
std::fill(nnmax_vec.begin(), nnmax_vec.end(), 0);
60+
DeePKS_domain::iterate_ad1(
61+
ucell,
62+
GridD,
63+
orb,
64+
false, // no trace_alpha
65+
[&](const int iat,
66+
const ModuleBase::Vector3<double>& tau0,
67+
const int ibt,
68+
const ModuleBase::Vector3<double>& tau1,
69+
const int start,
70+
const int nw_tot,
71+
ModuleBase::Vector3<int> dR)
72+
{
73+
hamilt::BaseMatrix<double>* overlap_mat = phialpha[0]->find_matrix(iat, ibt, dR);
74+
if (overlap_mat == nullptr)
75+
{
76+
return; // to next loop
77+
}
78+
dRmat_accessor[iat][nnmax_vec[iat]][0] = dR.x;
79+
dRmat_accessor[iat][nnmax_vec[iat]][1] = dR.y;
80+
dRmat_accessor[iat][nnmax_vec[iat]][2] = dR.z;
9381

94-
ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_r");
82+
for (int ix = 0; ix < nw_tot; ix++)
83+
{
84+
for (int iy = 0; iy < deepks_param.des_per_atom; iy++)
85+
{
86+
overlap_accessor[iat][nnmax_vec[iat]][start + ix][iy] = overlap_mat->get_value(ix, iy);
87+
}
88+
}
89+
nnmax_vec[iat]++;
90+
}
91+
);
// dRmat_tmp is {nat, nnmax, 3}; unsqueeze(1) gives {nat, 1, nnmax, 3} and unsqueeze(2) gives
// {nat, nnmax, 1, 3}, so the difference broadcasts to {nat, nnmax, nnmax, 3} with element
// [a][i][j] = dR_j - dR_i of atom a. The tensor overload of mapping_R is then applied elementwise.
92+
iRmat = mapping_R(dRmat_tmp.unsqueeze(1) - dRmat_tmp.unsqueeze(2));
93+
ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_iRmat");
9594
return;
9695
}
9796

@@ -253,6 +252,15 @@ int DeePKS_domain::mapping_R(int R)
253252
return R_index;
254253
}
255254

// Tensor overload of mapping_R: converts R_tensor to int32 and maps each element elementwise:
//   R > 0  ->  2R - 1
//   R <= 0 ->  -2R
// e.g. 0 -> 0, 1 -> 1, -1 -> 2, 2 -> 3, -2 -> 4.
// NOTE(review): presumably mirrors the scalar mapping_R(int), whose body is not visible in this
// hunk - confirm the two stay in sync.
255+
torch::Tensor DeePKS_domain::mapping_R(const torch::Tensor& R_tensor)
256+
{
257+
auto R = R_tensor.to(torch::kInt32);
258+
auto pos = R > 0;
// Both candidate results are computed for every element; at::where then picks per element.
259+
auto twoR_minus1 = R * 2 - 1;
260+
auto neg_minus2R = -2 * R;
261+
return at::where(pos, twoR_minus1, neg_minus2R);
262+
}
263+
256264
template <typename T>
257265
int DeePKS_domain::get_R_size(const hamilt::HContainer<T>& hcontainer)
258266
{

source/source_lcao/module_deepks/deepks_vdrpre.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ namespace DeePKS_domain
2929

3030
// for deepks_v_delta = -1
3131
// calculates v_delta_r_precalc
32-
void prepare_phialpha_r(const int nlocal,
33-
const int nat,
34-
const int R_size,
35-
const DeePKS_Param& deepks_param,
36-
const std::vector<hamilt::HContainer<double>*> phialpha,
37-
const UnitCell& ucell,
38-
const LCAO_Orbitals& orb,
39-
const Parallel_Orbitals& pv,
40-
const Grid_Driver& GridD,
41-
torch::Tensor& phialpha_r_out);
32+
// prepare_phialpha_iRmat (replaces prepare_phialpha_r in this commit; the nat and pv parameters
// are dropped and a second output tensor is added).
// Called from the deepks_v_delta == -2 branch of out_deepks_labels, which saves `overlap` to
// deepks_phialpha_r.npy as double and `iRmat` to deepks_iRmat.npy as int.
//   overlap : output, assigned by the function (float64 in the visible definition).
//   iRmat   : output, assigned by the function from mapping_R of dR differences.
// NOTE(review): R_size is still part of the signature but is not read by the visible definition.
33+
void prepare_phialpha_iRmat(const int nlocal,
34+
const int R_size,
35+
const DeePKS_Param& deepks_param,
36+
const std::vector<hamilt::HContainer<double>*> phialpha,
37+
const UnitCell& ucell,
38+
const LCAO_Orbitals& orb,
39+
const Grid_Driver& GridD,
40+
torch::Tensor& overlap,
41+
torch::Tensor& iRmat);
4242

4343
void cal_vdr_precalc(const int nlocal,
4444
const int nat,
@@ -55,6 +55,7 @@ void cal_vdr_precalc(const int nlocal,
5555
torch::Tensor& vdr_precalc);
5656

5757
int mapping_R(int R);
// Tensor overload: casts R_tensor to int32 and returns, elementwise, 2R - 1 where R > 0 and -2R
// otherwise.
58+
torch::Tensor mapping_R(const torch::Tensor& R_tensor);
5859

5960
template <typename T>
6061
int get_R_size(const hamilt::HContainer<T>& hcontainer);

0 commit comments

Comments
 (0)