|
1 | | -// prepare_phialpha_r : prepare phialpha_r for outputting npy file |
| 1 | +// prepare_phialpha_iRmat : prepare phialpha_r and iR_mat for outputting npy file |
2 | 2 |
|
3 | 3 | #ifdef __MLALGO |
4 | 4 |
|
|
13 | 13 | #include "source_io/module_parameter/parameter.h" |
14 | 14 | #include "source_lcao/module_hcontainer/atom_pair.h" |
15 | 15 |
|
16 | | -void DeePKS_domain::prepare_phialpha_r(const int nlocal, |
17 | | - const int nat, |
18 | | - const int R_size, |
19 | | - const DeePKS_Param& deepks_param, |
20 | | - const std::vector<hamilt::HContainer<double>*> phialpha, |
21 | | - const UnitCell& ucell, |
22 | | - const LCAO_Orbitals& orb, |
23 | | - const Parallel_Orbitals& pv, |
24 | | - const Grid_Driver& GridD, |
25 | | - torch::Tensor& phialpha_r_out) |
| 16 | +void DeePKS_domain::prepare_phialpha_iRmat(const int nlocal, |
| 17 | + const int R_size, |
| 18 | + const DeePKS_Param& deepks_param, |
| 19 | + const std::vector<hamilt::HContainer<double>*> phialpha, |
| 20 | + const UnitCell& ucell, |
| 21 | + const LCAO_Orbitals& orb, |
| 22 | + const Grid_Driver& GridD, |
| 23 | + torch::Tensor& overlap, |
| 24 | + torch::Tensor& iRmat) |
26 | 25 | { |
27 | | - ModuleBase::TITLE("DeePKS_domain", "prepare_phialpha_r"); |
28 | | - ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_r"); |
| 26 | + ModuleBase::TITLE("DeePKS_domain", "prepare_phialpha_iRmat"); |
| 27 | + ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_iRmat"); |
29 | 28 | constexpr torch::Dtype dtype = torch::kFloat64; |
30 | | - int nlmax = deepks_param.inlmax / nat; |
31 | | - int mmax = 2 * deepks_param.lmaxd + 1; |
32 | | - |
33 | | - phialpha_r_out = torch::zeros({R_size, R_size, R_size, nat, nlmax, nlocal, mmax}, dtype); |
34 | | - auto accessor = phialpha_r_out.accessor<double, 7>(); |
35 | | - |
36 | | - DeePKS_domain::iterate_ad1(ucell, |
37 | | - GridD, |
38 | | - orb, |
39 | | - false, // no trace_alpha |
40 | | - [&](const int iat, |
41 | | - const ModuleBase::Vector3<double>& tau0, |
42 | | - const int ibt, |
43 | | - const ModuleBase::Vector3<double>& tau, |
44 | | - const int start, |
45 | | - const int nw_tot, |
46 | | - ModuleBase::Vector3<int> dR) { |
47 | | - if (phialpha[0]->find_matrix(iat, ibt, dR.x, dR.y, dR.z) == nullptr) |
48 | | - { |
49 | | - return; // to next loop |
50 | | - } |
51 | 29 |
|
52 | | - // middle loop : all atomic basis on the adjacent atom ad |
53 | | - for (int iw1 = 0; iw1 < nw_tot; ++iw1) |
54 | | - { |
55 | | - const int iw1_all = start + iw1; |
56 | | - const int iw1_local = pv.global2local_row(iw1_all); |
57 | | - const int iw2_local = pv.global2local_col(iw1_all); |
58 | | - if (iw1_local < 0 || iw2_local < 0) |
59 | | - { |
60 | | - continue; |
61 | | - } |
62 | | - hamilt::BaseMatrix<double>* overlap = phialpha[0]->find_matrix(iat, ibt, dR); |
63 | | - const int iR = phialpha[0]->find_R(dR); |
64 | | - |
65 | | - int ib = 0; |
66 | | - int nl = 0; |
67 | | - for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0) |
68 | | - { |
69 | | - for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0) |
70 | | - { |
71 | | - const int nm = 2 * L0 + 1; |
72 | | - for (int m1 = 0; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d |
73 | | - { |
74 | | - int iRx = DeePKS_domain::mapping_R(dR.x); |
75 | | - int iRy = DeePKS_domain::mapping_R(dR.y); |
76 | | - int iRz = DeePKS_domain::mapping_R(dR.z); |
77 | | - accessor[iRx][iRy][iRz][iat][nl][iw1_all][m1] |
78 | | - += overlap->get_value(iw1, ib + m1); |
79 | | - } |
80 | | - ib += nm; |
81 | | - nl++; |
82 | | - } |
83 | | - } |
84 | | - } // end iw |
85 | | - }); |
86 | | - |
87 | | -#ifdef __MPI |
88 | | - int size = R_size * R_size * R_size * nat * nlmax * nlocal * mmax; |
89 | | - double* data_ptr = phialpha_r_out.data_ptr<double>(); |
90 | | - Parallel_Reduce::reduce_all(data_ptr, size); |
| 30 | + // get the maximum nnmax |
| 31 | + std::vector<int> nnmax_vec(ucell.nat, 0); |
| 32 | + DeePKS_domain::iterate_ad1( |
| 33 | + ucell, |
| 34 | + GridD, |
| 35 | + orb, |
| 36 | + false, // no trace_alpha |
| 37 | + [&](const int iat, |
| 38 | + const ModuleBase::Vector3<double>& tau0, |
| 39 | + const int ibt, |
| 40 | + const ModuleBase::Vector3<double>& tau1, |
| 41 | + const int start, |
| 42 | + const int nw_tot, |
| 43 | + ModuleBase::Vector3<int> dR) |
| 44 | + { |
| 45 | + if (phialpha[0]->find_matrix(iat, ibt, dR.x, dR.y, dR.z) == nullptr) |
| 46 | + { |
| 47 | + return; // to next loop |
| 48 | + } |
| 49 | + nnmax_vec[iat]++; |
| 50 | + } |
| 51 | + ); |
| 52 | + |
| 53 | + int nnmax = *std::max_element(nnmax_vec.begin(), nnmax_vec.end()); |
| 54 | + overlap = torch::zeros({ucell.nat, nnmax, nlocal, deepks_param.des_per_atom}, dtype); |
| 55 | + torch::Tensor dRmat_tmp = torch::zeros({ucell.nat, nnmax, 3}, torch::kInt32); |
| 56 | + auto overlap_accessor = overlap.accessor<double, 4>(); |
| 57 | + auto dRmat_accessor = dRmat_tmp.accessor<int, 3>(); |
91 | 58 |
|
92 | | -#endif |
| 59 | + std::fill(nnmax_vec.begin(), nnmax_vec.end(), 0); |
| 60 | + DeePKS_domain::iterate_ad1( |
| 61 | + ucell, |
| 62 | + GridD, |
| 63 | + orb, |
| 64 | + false, // no trace_alpha |
| 65 | + [&](const int iat, |
| 66 | + const ModuleBase::Vector3<double>& tau0, |
| 67 | + const int ibt, |
| 68 | + const ModuleBase::Vector3<double>& tau1, |
| 69 | + const int start, |
| 70 | + const int nw_tot, |
| 71 | + ModuleBase::Vector3<int> dR) |
| 72 | + { |
| 73 | + hamilt::BaseMatrix<double>* overlap_mat = phialpha[0]->find_matrix(iat, ibt, dR); |
| 74 | + if (overlap_mat == nullptr) |
| 75 | + { |
| 76 | + return; // to next loop |
| 77 | + } |
| 78 | + dRmat_accessor[iat][nnmax_vec[iat]][0] = dR.x; |
| 79 | + dRmat_accessor[iat][nnmax_vec[iat]][1] = dR.y; |
| 80 | + dRmat_accessor[iat][nnmax_vec[iat]][2] = dR.z; |
93 | 81 |
|
94 | | - ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_r"); |
| 82 | + for (int ix = 0; ix < nw_tot; ix++) |
| 83 | + { |
| 84 | + for (int iy = 0; iy < deepks_param.des_per_atom; iy++) |
| 85 | + { |
| 86 | + overlap_accessor[iat][nnmax_vec[iat]][start + ix][iy] = overlap_mat->get_value(ix, iy); |
| 87 | + } |
| 88 | + } |
| 89 | + nnmax_vec[iat]++; |
| 90 | + } |
| 91 | + ); |
| 92 | + iRmat = mapping_R(dRmat_tmp.unsqueeze(1) - dRmat_tmp.unsqueeze(2)); |
| 93 | + ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_iRmat"); |
95 | 94 | return; |
96 | 95 | } |
97 | 96 |
|
@@ -253,6 +252,15 @@ int DeePKS_domain::mapping_R(int R) |
253 | 252 | return R_index; |
254 | 253 | } |
255 | 254 |
|
| 255 | +torch::Tensor DeePKS_domain::mapping_R(const torch::Tensor& R_tensor) |
| 256 | +{ |
| 257 | + auto R = R_tensor.to(torch::kInt32); |
| 258 | + auto pos = R > 0; |
| 259 | + auto twoR_minus1 = R * 2 - 1; |
| 260 | + auto neg_minus2R = -2 * R; |
| 261 | + return at::where(pos, twoR_minus1, neg_minus2R); |
| 262 | +} |
| 263 | + |
256 | 264 | template <typename T> |
257 | 265 | int DeePKS_domain::get_R_size(const hamilt::HContainer<T>& hcontainer) |
258 | 266 | { |
|
0 commit comments