Skip to content

Commit bce3cf2

Browse files
committed
Add HR precalc functions for DeePKS and fix some bugs.
1 parent 924d29d commit bce3cf2

File tree

8 files changed

+179
-37
lines changed

8 files changed

+179
-37
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ OBJS_DEEPKS=LCAO_deepks.o\
208208
deepks_orbpre.o\
209209
deepks_vdelta.o\
210210
deepks_vdpre.o\
211+
deepks_vdrpre.o\
211212
deepks_hmat.o\
212213
deepks_pdm.o\
213214
deepks_phialpha.o\

source/module_hamilt_lcao/module_deepks/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ if(ENABLE_DEEPKS)
1111
deepks_orbpre.cpp
1212
deepks_vdelta.cpp
1313
deepks_vdpre.cpp
14+
deepks_vdrpre.cpp
1415
deepks_hmat.cpp
1516
deepks_pdm.cpp
1617
deepks_phialpha.cpp

source/module_hamilt_lcao/module_deepks/LCAO_deepks.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "deepks_spre.h"
1616
#include "deepks_vdelta.h"
1717
#include "deepks_vdpre.h"
18+
#include "deepks_vdrpre.h"
1819
#include "module_base/complexmatrix.h"
1920
#include "module_base/intarray.h"
2021
#include "module_base/matrix.h"

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -314,34 +314,18 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
314314
ofs_hr.close();
315315
}
316316

317-
const std::string file_vdrpre = PARAM.globalv.global_out_dir + "deepks_vdrpre.csr";
318-
std::vector<hamilt::HContainer<TR>*> h_deltaR_pre(inlmax);
319-
for (int i = 0; i < inlmax; i++)
320-
{
321-
h_deltaR_pre[i] = new hamilt::HContainer<TR>(*hR_tot);
322-
h_deltaR_pre[i]->set_zero();
323-
}
324-
// DeePKS_domain::cal_vdr_precalc<TR>();
325-
if (rank == 0)
326-
{
327-
std::ofstream ofs_hrp(file_vdrpre, std::ios::out);
328-
for (int iat = 0; iat < nat; iat++)
329-
{
330-
ofs_hrp << "- Index of atom: " << iat << std::endl;
331-
for (int nl = 0; nl < nlmax; nl++)
332-
{
333-
int inl = iat * nlmax + nl;
334-
ofs_hrp << "-- Index of nl: " << nl << std::endl;
335-
ofs_hrp << "Matrix Dimension of H_delta(R): " << h_deltaR_pre[inl]->get_nbasis() << std::endl;
336-
ofs_hrp << "Matrix number of H_delta(R): " << h_deltaR_pre[inl]->size_R_loop() << std::endl;
337-
hamilt::Output_HContainer<TR> out_hrp(h_deltaR_pre[inl], ofs_hrp, sparse_threshold, precision);
338-
out_hrp.write();
339-
ofs_hrp << std::endl;
340-
}
341-
ofs_hrp << std::endl;
342-
}
343-
ofs_hrp.close();
344-
}
317+
torch::Tensor phialpha_r_out;
318+
torch::Tensor R_query;
319+
DeePKS_domain::prepare_phialpha_r(nlocal, lmaxd, inlmax, nat, phialpha, ucell, orb, *ParaV, GridD, phialpha_r_out, R_query);
320+
const std::string file_phialpha_r = PARAM.globalv.global_out_dir + "deepks_phialpha_r.npy";
321+
const std::string file_R_query = PARAM.globalv.global_out_dir + "deepks_R_query.npy";
322+
LCAO_deepks_io::save_tensor2npy<double>(file_phialpha_r, phialpha_r_out, rank);
323+
LCAO_deepks_io::save_tensor2npy<int>(file_R_query, R_query, rank);
324+
325+
torch::Tensor gevdm_out;
326+
DeePKS_domain::prepare_gevdm(nat, lmaxd, inlmax, orb, gevdm, gevdm_out);
327+
const std::string file_gevdm = PARAM.globalv.global_out_dir + "deepks_gevdm.npy";
328+
LCAO_deepks_io::save_tensor2npy<double>(file_gevdm, gevdm_out, rank);
345329
}
346330
}
347331

source/module_hamilt_lcao/module_deepks/LCAO_deepks_io.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,18 +275,18 @@ void LCAO_deepks_io::save_tensor2npy(const std::string& file_name, const torch::
275275

276276
std::vector<T> data(tensor.numel());
277277

278-
if constexpr (std::is_same<T, double>::value)
279-
{
280-
std::memcpy(data.data(), tensor.data_ptr<double>(), tensor.numel() * sizeof(double));
281-
}
282-
else
278+
if constexpr (std::is_same<T, std::complex<double>>::value)
283279
{
284280
auto tensor_data = tensor.data_ptr<c10::complex<double>>();
285281
for (size_t i = 0; i < tensor.numel(); ++i)
286282
{
287283
data[i] = std::complex<double>(tensor_data[i].real(), tensor_data[i].imag());
288284
}
289285
}
286+
else
287+
{
288+
std::memcpy(data.data(), tensor.data_ptr<T>(), tensor.numel() * sizeof(T));
289+
}
290290

291291
npy::SaveArrayAsNumpy(file_name, false, shape.size(), shape.data(), data);
292292
}
@@ -313,6 +313,10 @@ template void LCAO_deepks_io::save_npy_h<std::complex<double>>(const std::vector
313313
const int nks,
314314
const int rank);
315315

316+
template void LCAO_deepks_io::save_tensor2npy<int>(const std::string& file_name,
317+
const torch::Tensor& tensor,
318+
const int rank);
319+
316320
template void LCAO_deepks_io::save_tensor2npy<double>(const std::string& file_name,
317321
const torch::Tensor& tensor,
318322
const int rank);

source/module_hamilt_lcao/module_deepks/deepks_vdpre.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ void DeePKS_domain::cal_v_delta_precalc(const int nlocal,
220220
std::vector<torch::Tensor> v_delta_precalc_vector;
221221
for (int nl = 0; nl < nlmax; ++nl)
222222
{
223-
torch::Tensor gevdm_complex = gevdm[nl].to(dtype);
224-
v_delta_precalc_vector.push_back(at::einsum("kxyamn, avmn->kxyav", {v_delta_pdm_vector[nl], gevdm[nl]}));
223+
torch::Tensor gevdm_totype = gevdm[nl].to(dtype);
224+
v_delta_precalc_vector.push_back(at::einsum("kxyamn, avmn->kxyav", {v_delta_pdm_vector[nl], gevdm_totype}));
225225
}
226226

227227
v_delta_precalc = torch::cat(v_delta_precalc_vector, -1);
@@ -296,6 +296,8 @@ void DeePKS_domain::prepare_phialpha(const int nlocal,
296296
int nlmax = inlmax / nat;
297297
int mmax = 2 * lmaxd + 1;
298298
phialpha_out = torch::zeros({nat, nlmax, nks, nlocal, mmax}, dtype);
299+
auto accessor
300+
= phialpha_out.accessor<std::conditional_t<std::is_same<TK, double>::value, double, c10::complex<double>>, 5>();
299301

300302
DeePKS_domain::iterate_ad1(
301303
ucell,
@@ -348,13 +350,13 @@ void DeePKS_domain::prepare_phialpha(const int nlocal,
348350
{
349351
if constexpr (std::is_same<TK, double>::value)
350352
{
351-
phialpha_out[iat][nl][ik][iw1_all][m1] = overlap->get_value(iw1, ib + m1);
353+
accessor[iat][nl][ik][iw1_all][m1] = overlap->get_value(iw1, ib + m1);
352354
}
353355
else
354356
{
355357
c10::complex<double> tmp;
356358
tmp = overlap->get_value(iw1, ib + m1) * kphase;
357-
phialpha_out.index_put_({iat, nl, ik, iw1_all, m1}, tmp);
359+
accessor[iat][nl][ik][iw1_all][m1] += tmp;
358360
}
359361
}
360362
ib += nm;
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// prepare_phialpha_r : prepare phialpha_r for outputting npy file
2+
3+
#ifdef __DEEPKS
4+
5+
#include "deepks_vdrpre.h"
6+
7+
#include "LCAO_deepks_io.h" // mohan add 2024-07-22
8+
#include "deepks_iterate.h"
9+
#include "module_base/blas_connector.h"
10+
#include "module_base/constants.h"
11+
#include "module_base/libm/libm.h"
12+
#include "module_base/parallel_reduce.h"
13+
#include "module_hamilt_lcao/module_hcontainer/atom_pair.h"
14+
#include "module_parameter/parameter.h"
15+
16+
void DeePKS_domain::prepare_phialpha_r(const int nlocal,
17+
const int lmaxd,
18+
const int inlmax,
19+
const int nat,
20+
const std::vector<hamilt::HContainer<double>*> phialpha,
21+
const UnitCell& ucell,
22+
const LCAO_Orbitals& orb,
23+
const Parallel_Orbitals& pv,
24+
const Grid_Driver& GridD,
25+
torch::Tensor& phialpha_r_out,
26+
torch::Tensor& R_query)
27+
{
28+
ModuleBase::TITLE("DeePKS_domain", "prepare_phialpha_r");
29+
ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_r");
30+
constexpr torch::Dtype dtype = torch::kFloat64;
31+
int nlmax = inlmax / nat;
32+
int mmax = 2 * lmaxd + 1;
33+
auto size_R = static_cast<long>(phialpha[0]->size_R_loop());
34+
phialpha_r_out = torch::zeros({size_R, nat, nlmax, nlocal, mmax}, dtype);
35+
R_query = torch::zeros({size_R, 3}, torch::kInt32);
36+
auto accessor = phialpha_r_out.accessor<double, 5>();
37+
auto R_accessor = R_query.accessor<int, 2>();
38+
39+
for (int iR = 0; iR < size_R; ++iR)
40+
{
41+
phialpha[0]->loop_R(iR, R_accessor[iR][0], R_accessor[iR][1], R_accessor[iR][2]);
42+
}
43+
44+
DeePKS_domain::iterate_ad1(
45+
ucell,
46+
GridD,
47+
orb,
48+
false, // no trace_alpha
49+
[&](const int iat,
50+
const ModuleBase::Vector3<double>& tau0,
51+
const int ibt,
52+
const ModuleBase::Vector3<double>& tau,
53+
const int start,
54+
const int nw_tot,
55+
ModuleBase::Vector3<int> dR)
56+
{
57+
if (phialpha[0]->find_matrix(iat, ibt, dR.x, dR.y, dR.z) == nullptr)
58+
{
59+
return; // to next loop
60+
}
61+
62+
// middle loop : all atomic basis on the adjacent atom ad
63+
for (int iw1 = 0; iw1 < nw_tot; ++iw1)
64+
{
65+
const int iw1_all = start + iw1;
66+
const int iw1_local = pv.global2local_row(iw1_all);
67+
const int iw2_local = pv.global2local_col(iw1_all);
68+
if (iw1_local < 0 || iw2_local < 0)
69+
{
70+
continue;
71+
}
72+
hamilt::BaseMatrix<double>* overlap = phialpha[0]->find_matrix(iat, ibt, dR);
73+
const int iR = phialpha[0]->find_R(dR);
74+
75+
int ib = 0;
76+
int nl = 0;
77+
for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0)
78+
{
79+
for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0)
80+
{
81+
const int nm = 2 * L0 + 1;
82+
for (int m1 = 0; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d
83+
{
84+
accessor[iR][iat][nl][iw1_all][m1] += overlap->get_value(iw1, ib + m1);
85+
}
86+
ib += nm;
87+
nl++;
88+
}
89+
}
90+
} // end iw
91+
}
92+
);
93+
94+
#ifdef __MPI
95+
int size = size_R * nat * nlmax * nlocal * mmax;
96+
double* data_ptr = phialpha_r_out.data_ptr<double>();
97+
Parallel_Reduce::reduce_all(data_ptr, size);
98+
99+
#endif
100+
101+
ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_r");
102+
return;
103+
}
104+
#endif
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#ifndef DEEPKS_VDRPRE_H
2+
#define DEEPKS_VDRPRE_H
3+
4+
#ifdef __DEEPKS
5+
6+
#include "module_base/complexmatrix.h"
7+
#include "module_base/intarray.h"
8+
#include "module_base/matrix.h"
9+
#include "module_base/timer.h"
10+
#include "module_basis/module_ao/parallel_orbitals.h"
11+
#include "module_basis/module_nao/two_center_integrator.h"
12+
#include "module_cell/module_neighbor/sltk_grid_driver.h"
13+
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"
14+
15+
#include <torch/script.h>
16+
#include <torch/torch.h>
17+
18+
namespace DeePKS_domain
19+
{
20+
//------------------------
21+
// deepks_vdrpre.cpp
22+
//------------------------
23+
24+
// This file contains 1 subroutine that prepares data for the v_delta_r label,
25+
// prepare_phialpha_r : packs the real-space overlaps stored in phialpha and their R indices into tensors for npy output;
26+
// v_delta_r_precalc, which equals gevdm * v_delta_pdm,
27+
// with v_delta_pdm = overlap * overlap, is not computed here
28+
29+
// for deepks_v_delta = -1
30+
// prepares phialpha_r_out and R_query
31+
/// Packs the overlap blocks of phialpha[0] into a dense tensor plus an R-index table.
/// @param nlocal          extent of the 4th dimension of phialpha_r_out (global orbital index)
/// @param lmaxd           last dimension of phialpha_r_out is 2 * lmaxd + 1
/// @param inlmax          3rd dimension of phialpha_r_out is inlmax / nat
/// @param nat             extent of the 2nd dimension of phialpha_r_out
/// @param phialpha        only element 0 is read by the implementation
/// @param ucell           forwarded to iterate_ad1
/// @param orb             forwarded to iterate_ad1; orb.Alpha[0] defines the (L, N) loops
/// @param pv              orbitals with a negative local row or column index are skipped
/// @param GridD           forwarded to iterate_ad1
/// @param phialpha_r_out  [out] float64 tensor {size_R, nat, inlmax / nat, nlocal, 2 * lmaxd + 1}
/// @param R_query         [out] int32 tensor {size_R, 3} holding the R vector of each slot
void prepare_phialpha_r(const int nlocal,
32+
const int lmaxd,
33+
const int inlmax,
34+
const int nat,
35+
const std::vector<hamilt::HContainer<double>*> phialpha,
36+
const UnitCell& ucell,
37+
const LCAO_Orbitals& orb,
38+
const Parallel_Orbitals& pv,
39+
const Grid_Driver& GridD,
40+
torch::Tensor& phialpha_r_out,
41+
torch::Tensor& R_query);
42+
43+
} // namespace DeePKS_domain
44+
#endif
45+
#endif

0 commit comments

Comments
 (0)