Skip to content

Commit 3378c36

Browse files
committed
Add support for DeePKS Real space Hamiltonian.
1 parent 0d46c99 commit 3378c36

File tree

5 files changed

+291
-62
lines changed

5 files changed

+291
-62
lines changed

source/module_hamilt_lcao/module_deepks/LCAO_deepks_interface.cpp

Lines changed: 63 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -274,14 +274,13 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
274274
} // end deepks_out_labels == 1
275275
} // end bandgap label
276276

277-
// not add deepks_out_labels = 2 for HR yet
278277
// H(R) matrix part, for HR, base will not be calculated since they are HContainer objects
279278
if (PARAM.inp.deepks_v_delta < 0)
280279
{
281280
// set the output
282281
const double sparse_threshold = 1e-10;
283282
const int precision = 8;
284-
const std::string file_hrtot = PARAM.globalv.global_out_dir + "deepks_hrtot.csr";
283+
const std::string file_hrtot = PARAM.globalv.global_out_dir + (PARAM.inp.deepks_out_labels == 1 ? "deepks_hrtot.csr" : "hamiltonian_r.csr");
285284
hamilt::HContainer<TR>* hR_tot = (p_ham->getHR());
286285

287286
if (rank == 0)
@@ -290,47 +289,75 @@ void LCAO_Deepks_Interface<TK, TR>::out_deepks_labels(const double& etot,
290289
ofs_hr << "Matrix Dimension of H(R): " << hR_tot->get_nbasis() << std::endl;
291290
ofs_hr << "Matrix number of H(R): " << hR_tot->size_R_loop() << std::endl;
292291
hamilt::Output_HContainer<TR> out_hr(hR_tot, ofs_hr, sparse_threshold, precision);
293-
out_hr.write();
292+
out_hr.write(true); // write all the matrices, including empty ones
294293
ofs_hr.close();
295294
}
296295

297296
if (PARAM.inp.deepks_scf)
298297
{
299-
const std::string file_vdeltar = PARAM.globalv.global_out_dir + "deepks_hrdelta.csr";
300-
hamilt::HContainer<TR>* h_deltaR = p_ham->get_V_delta_R();
301-
302-
if (rank == 0)
298+
if (PARAM.inp.deepks_out_labels == 1)
303299
{
304-
std::ofstream ofs_hr(file_vdeltar, std::ios::out);
305-
ofs_hr << "Matrix Dimension of H_delta(R): " << h_deltaR->get_nbasis() << std::endl;
306-
ofs_hr << "Matrix number of H_delta(R): " << h_deltaR->size_R_loop() << std::endl;
307-
hamilt::Output_HContainer<TR> out_hr(h_deltaR, ofs_hr, sparse_threshold, precision);
308-
out_hr.write();
309-
ofs_hr.close();
310-
}
300+
const std::string file_vdeltar = PARAM.globalv.global_out_dir + "deepks_hrdelta.csr";
301+
hamilt::HContainer<TR>* h_deltaR = p_ham->get_V_delta_R();
302+
303+
if (rank == 0)
304+
{
305+
std::ofstream ofs_hr(file_vdeltar, std::ios::out);
306+
ofs_hr << "Matrix Dimension of H_delta(R): " << h_deltaR->get_nbasis() << std::endl;
307+
ofs_hr << "Matrix number of H_delta(R): " << h_deltaR->size_R_loop() << std::endl;
308+
hamilt::Output_HContainer<TR> out_hr(h_deltaR, ofs_hr, sparse_threshold, precision);
309+
out_hr.write(true); // write all the matrices, including empty ones
310+
ofs_hr.close();
311+
}
311312

312-
torch::Tensor phialpha_r_out;
313-
torch::Tensor R_query;
314-
DeePKS_domain::prepare_phialpha_r(nlocal,
315-
lmaxd,
316-
inlmax,
317-
nat,
318-
phialpha,
319-
ucell,
320-
orb,
321-
*ParaV,
322-
GridD,
323-
phialpha_r_out,
324-
R_query);
325-
const std::string file_phialpha_r = PARAM.globalv.global_out_dir + "deepks_phialpha_r.npy";
326-
const std::string file_R_query = PARAM.globalv.global_out_dir + "deepks_R_query.npy";
327-
LCAO_deepks_io::save_tensor2npy<double>(file_phialpha_r, phialpha_r_out, rank);
328-
LCAO_deepks_io::save_tensor2npy<int>(file_R_query, R_query, rank);
329-
330-
torch::Tensor gevdm_out;
331-
DeePKS_domain::prepare_gevdm(nat, lmaxd, inlmax, orb, gevdm, gevdm_out);
332-
const std::string file_gevdm = PARAM.globalv.global_out_dir + "deepks_gevdm.npy";
333-
LCAO_deepks_io::save_tensor2npy<double>(file_gevdm, gevdm_out, rank);
313+
if (PARAM.inp.deepks_v_delta == -1)
314+
{
315+
int R_size = DeePKS_domain::get_R_size(*h_deltaR);
316+
torch::Tensor vdr_precalc;
317+
DeePKS_domain::cal_vdr_precalc(nlocal,
318+
lmaxd,
319+
inlmax,
320+
nat,
321+
nks,
322+
R_size,
323+
inl2l,
324+
kvec_d,
325+
phialpha,
326+
gevdm,
327+
inl_index,
328+
ucell,
329+
orb,
330+
*ParaV,
331+
GridD,
332+
vdr_precalc);
333+
334+
const std::string file_vdrpre = PARAM.globalv.global_out_dir + "deepks_vdrpre.npy";
335+
LCAO_deepks_io::save_tensor2npy<double>(file_vdrpre, vdr_precalc, rank);
336+
}
337+
else if (PARAM.inp.deepks_v_delta == -2)
338+
{
339+
int R_size = DeePKS_domain::get_R_size(*h_deltaR);
340+
torch::Tensor phialpha_r_out;
341+
DeePKS_domain::prepare_phialpha_r(nlocal,
342+
lmaxd,
343+
inlmax,
344+
nat,
345+
R_size,
346+
phialpha,
347+
ucell,
348+
orb,
349+
*ParaV,
350+
GridD,
351+
phialpha_r_out);
352+
const std::string file_phialpha_r = PARAM.globalv.global_out_dir + "deepks_phialpha_r.npy";
353+
LCAO_deepks_io::save_tensor2npy<double>(file_phialpha_r, phialpha_r_out, rank);
354+
355+
torch::Tensor gevdm_out;
356+
DeePKS_domain::prepare_gevdm(nat, lmaxd, inlmax, orb, gevdm, gevdm_out);
357+
const std::string file_gevdm = PARAM.globalv.global_out_dir + "deepks_gevdm.npy";
358+
LCAO_deepks_io::save_tensor2npy<double>(file_gevdm, gevdm_out, rank);
359+
}
360+
}
334361
}
335362
}
336363

source/module_hamilt_lcao/module_deepks/deepks_vdrpre.cpp

Lines changed: 198 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,25 @@
1414
#include "module_parameter/parameter.h"
1515

1616
void DeePKS_domain::prepare_phialpha_r(const int nlocal,
17-
const int lmaxd,
18-
const int inlmax,
19-
const int nat,
20-
const std::vector<hamilt::HContainer<double>*> phialpha,
21-
const UnitCell& ucell,
22-
const LCAO_Orbitals& orb,
23-
const Parallel_Orbitals& pv,
24-
const Grid_Driver& GridD,
25-
torch::Tensor& phialpha_r_out,
26-
torch::Tensor& R_query)
17+
const int lmaxd,
18+
const int inlmax,
19+
const int nat,
20+
const int R_size,
21+
const std::vector<hamilt::HContainer<double>*> phialpha,
22+
const UnitCell& ucell,
23+
const LCAO_Orbitals& orb,
24+
const Parallel_Orbitals& pv,
25+
const Grid_Driver& GridD,
26+
torch::Tensor& phialpha_r_out)
2727
{
2828
ModuleBase::TITLE("DeePKS_domain", "prepare_phialpha_r");
2929
ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_r");
3030
constexpr torch::Dtype dtype = torch::kFloat64;
3131
int nlmax = inlmax / nat;
3232
int mmax = 2 * lmaxd + 1;
33-
auto size_R = static_cast<long>(phialpha[0]->size_R_loop());
34-
phialpha_r_out = torch::zeros({size_R, nat, nlmax, nlocal, mmax}, dtype);
35-
R_query = torch::zeros({size_R, 3}, torch::kInt32);
36-
auto accessor = phialpha_r_out.accessor<double, 5>();
37-
auto R_accessor = R_query.accessor<int, 2>();
3833

39-
for (int iR = 0; iR < size_R; ++iR)
40-
{
41-
phialpha[0]->loop_R(iR, R_accessor[iR][0], R_accessor[iR][1], R_accessor[iR][2]);
42-
}
34+
phialpha_r_out = torch::zeros({R_size, R_size, R_size, nat, nlmax, nlocal, mmax}, dtype);
35+
auto accessor = phialpha_r_out.accessor<double, 7>();
4336

4437
DeePKS_domain::iterate_ad1(
4538
ucell,
@@ -81,18 +74,22 @@ void DeePKS_domain::prepare_phialpha_r(const int nlocal,
8174
const int nm = 2 * L0 + 1;
8275
for (int m1 = 0; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d
8376
{
84-
accessor[iR][iat][nl][iw1_all][m1] += overlap->get_value(iw1, ib + m1);
77+
int iRx = DeePKS_domain::mapping_R(dR.x);
78+
int iRy = DeePKS_domain::mapping_R(dR.y);
79+
int iRz = DeePKS_domain::mapping_R(dR.z);
80+
accessor[iRx][iRy][iRz][iat][nl][iw1_all][m1]
81+
+= overlap->get_value(iw1, ib + m1);
8582
}
8683
ib += nm;
8784
nl++;
8885
}
8986
}
90-
} // end iw
87+
} // end iw
9188
}
9289
);
9390

9491
#ifdef __MPI
95-
int size = size_R * nat * nlmax * nlocal * mmax;
92+
int size = R_size * R_size * R_size * nat * nlmax * nlocal * mmax;
9693
double* data_ptr = phialpha_r_out.data_ptr<double>();
9794
Parallel_Reduce::reduce_all(data_ptr, size);
9895

@@ -101,4 +98,183 @@ void DeePKS_domain::prepare_phialpha_r(const int nlocal,
10198
ModuleBase::timer::tick("DeePKS_domain", "prepare_phialpha_r");
10299
return;
103100
}
101+
102+
void DeePKS_domain::cal_vdr_precalc(const int nlocal,
103+
const int lmaxd,
104+
const int inlmax,
105+
const int nat,
106+
const int nks,
107+
const int R_size,
108+
const std::vector<int>& inl2l,
109+
const std::vector<ModuleBase::Vector3<double>>& kvec_d,
110+
const std::vector<hamilt::HContainer<double>*> phialpha,
111+
const std::vector<torch::Tensor> gevdm,
112+
const ModuleBase::IntArray* inl_index,
113+
const UnitCell& ucell,
114+
const LCAO_Orbitals& orb,
115+
const Parallel_Orbitals& pv,
116+
const Grid_Driver& GridD,
117+
torch::Tensor& vdr_precalc)
118+
{
119+
ModuleBase::TITLE("DeePKS_domain", "calc_vdr_precalc");
120+
ModuleBase::timer::tick("DeePKS_domain", "calc_vdr_precalc");
121+
122+
torch::Tensor vdr_pdm
123+
= torch::zeros({R_size, R_size, R_size, nlocal, nlocal, inlmax, (2 * lmaxd + 1), (2 * lmaxd + 1)},
124+
torch::TensorOptions().dtype(torch::kFloat64));
125+
auto accessor = vdr_pdm.accessor<double, 8>();
126+
127+
DeePKS_domain::iterate_ad2(
128+
ucell,
129+
GridD,
130+
orb,
131+
false, // no trace_alpha
132+
[&](const int iat,
133+
const ModuleBase::Vector3<double>& tau0,
134+
const int ibt1,
135+
const ModuleBase::Vector3<double>& tau1,
136+
const int start1,
137+
const int nw1_tot,
138+
ModuleBase::Vector3<int> dR1,
139+
const int ibt2,
140+
const ModuleBase::Vector3<double>& tau2,
141+
const int start2,
142+
const int nw2_tot,
143+
ModuleBase::Vector3<int> dR2)
144+
{
145+
const int T0 = ucell.iat2it[iat];
146+
const int I0 = ucell.iat2ia[iat];
147+
if (phialpha[0]->find_matrix(iat, ibt1, dR1.x, dR1.y, dR1.z) == nullptr
148+
|| phialpha[0]->find_matrix(iat, ibt2, dR2.x, dR2.y, dR2.z) == nullptr)
149+
{
150+
return; // to next loop
151+
}
152+
153+
hamilt::BaseMatrix<double>* overlap_1 = phialpha[0]->find_matrix(iat, ibt1, dR1);
154+
hamilt::BaseMatrix<double>* overlap_2 = phialpha[0]->find_matrix(iat, ibt2, dR2);
155+
assert(overlap_1->get_col_size() == overlap_2->get_col_size());
156+
ModuleBase::Vector3<int> dR = dR1 - dR2;
157+
int iRx = DeePKS_domain::mapping_R(dR.x);
158+
int iRy = DeePKS_domain::mapping_R(dR.y);
159+
int iRz = DeePKS_domain::mapping_R(dR.z);
160+
161+
for (int iw1 = 0; iw1 < nw1_tot; ++iw1)
162+
{
163+
const int iw1_all = start1 + iw1; // this is \mu
164+
const int iw1_local = pv.global2local_row(iw1_all);
165+
if (iw1_local < 0)
166+
{
167+
continue;
168+
}
169+
for (int iw2 = 0; iw2 < nw2_tot; ++iw2)
170+
{
171+
const int iw2_all = start2 + iw2; // this is \nu
172+
const int iw2_local = pv.global2local_col(iw2_all);
173+
if (iw2_local < 0)
174+
{
175+
continue;
176+
}
177+
178+
int ib = 0;
179+
for (int L0 = 0; L0 <= orb.Alpha[0].getLmax(); ++L0)
180+
{
181+
for (int N0 = 0; N0 < orb.Alpha[0].getNchi(L0); ++N0)
182+
{
183+
const int inl = inl_index[T0](I0, L0, N0);
184+
const int nm = 2 * L0 + 1;
185+
186+
for (int m1 = 0; m1 < nm; ++m1) // nm = 1 for s, 3 for p, 5 for d
187+
{
188+
for (int m2 = 0; m2 < nm; ++m2) // nm = 1 for s, 3 for p, 5 for d
189+
{
190+
double tmp = overlap_1->get_value(iw1, ib + m1)
191+
* overlap_2->get_value(iw2, ib + m2);
192+
accessor[iRx][iRy][iRz][iw1_all][iw2_all][inl][m1][m2]
193+
+= tmp;
194+
}
195+
}
196+
ib += nm;
197+
}
198+
}
199+
} // iw2
200+
} // iw1
201+
}
202+
);
203+
204+
#ifdef __MPI
205+
const int size = R_size * R_size * R_size * nlocal * nlocal * inlmax * (2 * lmaxd + 1) * (2 * lmaxd + 1);
206+
double* data_ptr = vdr_pdm.data_ptr<double>();
207+
Parallel_Reduce::reduce_all(data_ptr, size);
208+
#endif
209+
210+
// transfer v_delta_pdm to v_delta_pdm_vector
211+
int nlmax = inlmax / nat;
212+
std::vector<torch::Tensor> vdr_pdm_vector;
213+
for (int nl = 0; nl < nlmax; ++nl)
214+
{
215+
int nm = 2 * inl2l[nl] + 1;
216+
torch::Tensor vdr_pdm_sliced = vdr_pdm.slice(5, nl, inlmax, nlmax).slice(6, 0, nm, 1).slice(7, 0, nm, 1);
217+
vdr_pdm_vector.push_back(vdr_pdm_sliced);
218+
}
219+
220+
assert(vdr_pdm_vector.size() == nlmax);
221+
222+
// einsum for each nl:
223+
std::vector<torch::Tensor> vdr_vector;
224+
for (int nl = 0; nl < nlmax; ++nl)
225+
{
226+
vdr_vector.push_back(at::einsum("pqrxyamn, avmn->pqrxyav", {vdr_pdm_vector[nl], gevdm[nl]}));
227+
}
228+
229+
vdr_precalc = torch::cat(vdr_vector, -1);
230+
231+
ModuleBase::timer::tick("DeePKS_domain", "calc_vdr_precalc");
232+
return;
233+
}
234+
235+
int DeePKS_domain::mapping_R(int R)
236+
{
237+
// R_index mapping: index(R) = 2R-1 if R > 0, index(R) = -2R if R <= 0
238+
// after mapping, the new index [0,1,2,3,4,...] -> old index [0,1,-1,2,-2,...]
239+
// This manipulation makes sure that the new index is natural number
240+
// which makes it available to be used as index in torch::Tensor
241+
int R_index = 0;
242+
if (R > 0)
243+
{
244+
R_index = 2 * R - 1;
245+
}
246+
else
247+
{
248+
R_index = -2 * R;
249+
}
250+
return R_index;
251+
}
252+
253+
template <typename T>
254+
int DeePKS_domain::get_R_size(const hamilt::HContainer<T>& hcontainer)
255+
{
256+
// get R_size from hcontainer
257+
int R_size = 0;
258+
if (hcontainer.size_R_loop() > 0)
259+
{
260+
for (int iR = 0; iR < hcontainer.size_R_loop(); ++iR)
261+
{
262+
ModuleBase::Vector3<int> R_vec;
263+
hcontainer.loop_R(iR, R_vec.x, R_vec.y, R_vec.z);
264+
int R_min = std::min({R_vec.x, R_vec.y, R_vec.z});
265+
int R_max = std::max({R_vec.x, R_vec.y, R_vec.z});
266+
int tmp_R_size = std::max(DeePKS_domain::mapping_R(R_min), DeePKS_domain::mapping_R(R_max)) + 1;
267+
if (tmp_R_size > R_size)
268+
{
269+
R_size = tmp_R_size;
270+
}
271+
}
272+
}
273+
assert(R_size > 0);
274+
return R_size;
275+
}
276+
277+
template int DeePKS_domain::get_R_size<double>(const hamilt::HContainer<double>& hcontainer);
278+
template int DeePKS_domain::get_R_size<std::complex<double>>(
279+
const hamilt::HContainer<std::complex<double>>& hcontainer);
104280
#endif

0 commit comments

Comments
 (0)