|
| 1 | +// prepare_phialpha_r : prepare phialpha_r for outputting npy file |
| 2 | + |
| 3 | +#ifdef __DEEPKS |
| 4 | + |
| 5 | +#include "deepks_vdrpre.h" |
| 6 | + |
| 7 | +#include "LCAO_deepks_io.h" // mohan add 2024-07-22 |
| 8 | +#include "deepks_iterate.h" |
| 9 | +#include "module_base/blas_connector.h" |
| 10 | +#include "module_base/constants.h" |
| 11 | +#include "module_base/libm/libm.h" |
| 12 | +#include "module_base/parallel_reduce.h" |
| 13 | +#include "module_hamilt_lcao/module_hcontainer/atom_pair.h" |
| 14 | +#include "module_parameter/parameter.h" |
| 15 | + |
| 16 | +void DeePKS_domain::prepare_phialpha_r(const int nlocal, |
| 17 | + const int lmaxd, |
| 18 | + const int inlmax, |
| 19 | + const int nat, |
| 20 | + const std::vector<hamilt::HContainer<double>*> phialpha, |
| 21 | + const UnitCell& ucell, |
| 22 | + const LCAO_Orbitals& orb, |
| 23 | + const Parallel_Orbitals& pv, |
| 24 | + const Grid_Driver& GridD, |
| 25 | + torch::Tensor& phialpha_r_out, |
| 26 | + torch::Tensor& R_query) |
| 27 | +{ |
| 28 | + ModuleBase::TITLE("DeePKS_domain", "prepare_phialpha_r"); |
| 29 | + ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_r"); |
| 30 | + constexpr torch::Dtype dtype = torch::kFloat64; |
| 31 | + int nlmax = inlmax / nat; |
| 32 | + int mmax = 2 * lmaxd + 1; |
| 33 | + auto size_R = static_cast<long>(phialpha[0]->size_R_loop()); |
| 34 | + phialpha_r_out = torch::zeros({size_R, nat, nlmax, nlocal, mmax}, dtype); |
| 35 | + R_query = torch::zeros({size_R, 3}, torch::kInt32); |
| 36 | + auto accessor = phialpha_r_out.accessor<double, 5>(); |
| 37 | + auto R_accessor = R_query.accessor<int, 2>(); |
| 38 | + |
| 39 | + for (int iR = 0; iR < size_R; ++iR) |
| 40 | + { |
| 41 | + phialpha[0]->loop_R(iR, R_accessor[iR][0], R_accessor[iR][1], R_accessor[iR][2]); |
| 42 | + } |
| 43 | + |
| 44 | + DeePKS_domain::iterate_ad1( |
| 45 | + ucell, |
| 46 | + GridD, |
| 47 | + orb, |
| 48 | + false, // no trace_alpha |
| 49 | + [&](const int iat, |
| 50 | + const ModuleBase::Vector3<double>& tau0, |
| 51 | + const int ibt, |
| 52 | + const ModuleBase::Vector3<double>& tau, |
| 53 | + const int start, |
| 54 | + const int nw_tot, |
| 55 | + ModuleBase::Vector3<int> dR) |
| 56 | + { |
| 57 | + if (phialpha[0]->find_matrix(iat, ibt, dR.x, dR.y, dR.z) == nullptr) |
| 58 | + { |
| 59 | + return; // to next loop |
| 60 | + } |
| 61 | + |
| 62 | + // middle loop : all atomic basis on the adjacent atom ad |
| 63 | + for (int iw1 = 0; iw1 < nw_tot; ++iw1) |
| 64 | + { |
| 65 | + const int iw1_all = start + iw1; |
| 66 | + const int iw1_local = pv.global2local_row(iw1_all); |
| 67 | + const int iw2_local = pv.global2local_col(iw1_all); |
| 68 | + if (iw1_local < 0 || iw2_local < 0) |
| 69 | + { |
| 70 | + continue; |
| 71 | + } |
| 72 | + hamilt::BaseMatrix<double>* overlap = phialpha[0]->find_matrix(iat, ibt, dR); |
| 73 | + const int iR = phialpha[0]->find_R(dR); |
| 74 | + |
| 75 | + int ib = 0; |
| 76 | + int nl = 0; |
| 77 | + for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0) |
| 78 | + { |
| 79 | + for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0) |
| 80 | + { |
| 81 | + const int nm = 2 * L0 + 1; |
| 82 | + for (int m1 = 0; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d |
| 83 | + { |
| 84 | + accessor[iR][iat][nl][iw1_all][m1] += overlap->get_value(iw1, ib + m1); |
| 85 | + } |
| 86 | + ib += nm; |
| 87 | + nl++; |
| 88 | + } |
| 89 | + } |
| 90 | + } // end iw |
| 91 | + } |
| 92 | + ); |
| 93 | + |
| 94 | +#ifdef __MPI |
| 95 | + int size = size_R * nat * nlmax * nlocal * mmax; |
| 96 | + double* data_ptr = phialpha_r_out.data_ptr<double>(); |
| 97 | + Parallel_Reduce::reduce_all(data_ptr, size); |
| 98 | + |
| 99 | +#endif |
| 100 | + |
| 101 | + ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_r"); |
| 102 | + return; |
| 103 | +} |
| 104 | +#endif |
0 commit comments