Skip to content

Commit b67545c

Browse files
authored
Merge branch 'develop' into move_fft
2 parents 2674706 + c931d97 commit b67545c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+7064
-46791
lines changed

source/source_lcao/module_deepks/deepks_basic.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ void DeePKS_domain::load_model(const std::string& model_file, torch::jit::script
6161
ModuleBase::TITLE("DeePKS_domain", "load_model");
6262
ModuleBase::timer::tick("DeePKS_domain", "load_model");
6363

64+
// check whether file exists
65+
std::ifstream ifs(model_file.c_str());
66+
if (!ifs)
67+
{
68+
ModuleBase::timer::tick("DeePKS_domain", "load_model");
69+
ModuleBase::WARNING_QUIT("DeePKS_domain::load_model", "No model file named " + model_file + ", please check!");
70+
return;
71+
}
72+
ifs.close();
6473
try
6574
{
6675
model = torch::jit::load(model_file);

source/source_lcao/module_deepks/deepks_vdrpre.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,11 @@ void DeePKS_domain::cal_vdr_precalc(const int nlocal,
157157
int iRx = DeePKS_domain::mapping_R(dR.x);
158158
int iRy = DeePKS_domain::mapping_R(dR.y);
159159
int iRz = DeePKS_domain::mapping_R(dR.z);
160+
// Make sure the index is in range we need to save
161+
if (iRx >= R_size || iRy >= R_size || iRz >= R_size)
162+
{
163+
return; // to next loop
164+
}
160165

161166
for (int iw1 = 0; iw1 < nw1_tot; ++iw1)
162167
{

source/source_lcao/module_deepks/test/LCAO_deepks_test.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,8 @@ void test_deepks<T>::check_vdrpre()
299299
std::vector<torch::Tensor> gevdm;
300300
torch::Tensor vdrpre;
301301
DeePKS_domain::cal_gevdm(ucell.nat, this->ld.inlmax, this->ld.inl2l, this->ld.pdm, gevdm);
302-
// normally use hR to get R_size, here use phialpha[0] only for test case
303-
int R_size = DeePKS_domain::get_R_size<double>(*(this->ld.phialpha[0]));
302+
// normally use hR to get R_size, here use 3 instead for Bravo lattice R in [-1,0,1]
303+
int R_size = 3;
304304
DeePKS_domain::cal_vdr_precalc(PARAM.sys.nlocal,
305305
this->ld.lmaxd,
306306
this->ld.inlmax,
@@ -317,8 +317,8 @@ void test_deepks<T>::check_vdrpre()
317317
ParaO,
318318
Test_Deepks::GridD,
319319
vdrpre);
320-
// vdrpre is large, we only check the main element in Bravo lattice vector (0, 0, 0)
321-
torch::Tensor vdrpre_sliced = vdrpre.slice(0, 0, 1, 1).slice(1, 0, 1, 1).slice(2, 0, 1, 1);
320+
// vdrpre is large, we only check the main element in Bravo lattice vector (0, 0, 0) and (1, 0, 0)
321+
torch::Tensor vdrpre_sliced = vdrpre.slice(0, 0, 2, 1).slice(1, 0, 1, 1).slice(2, 0, 1, 1);
322322
DeePKS_domain::check_tensor<double>(vdrpre_sliced, "vdr_precalc.dat", 0); // 0 for rank
323323
this->compare_with_ref("vdr_precalc.dat", "vdrpre_ref.dat");
324324
}
@@ -462,8 +462,8 @@ void test_deepks<T>::compare_with_ref(const std::string f1, const std::string f2
462462
file2 >> word2;
463463
if ((word1[0] - '0' >= 0 && word1[0] - '0' < 10) || word1[0] == '-')
464464
{
465-
double num1 = std::stof(word1);
466-
double num2 = std::stof(word2);
465+
double num1 = std::stod(word1);
466+
double num2 = std::stod(word2);
467467
if (std::abs(num1 - num2) > test_thr)
468468
{
469469
this->failed_check += 1;
@@ -476,10 +476,10 @@ void test_deepks<T>::compare_with_ref(const std::string f1, const std::string f2
476476
{
477477
std::string word1_str = word1.substr(1, word1.size() - 2);
478478
std::string word2_str = word2.substr(1, word2.size() - 2);
479-
double word1_real = std::stof(word1_str.substr(0, word1_str.find(',')));
480-
double word1_imag = std::stof(word1_str.substr(word1_str.find(',') + 1));
481-
double word2_real = std::stof(word2_str.substr(0, word2_str.find(',')));
482-
double word2_imag = std::stof(word2_str.substr(word2_str.find(',') + 1));
479+
double word1_real = std::stod(word1_str.substr(0, word1_str.find(',')));
480+
double word1_imag = std::stod(word1_str.substr(word1_str.find(',') + 1));
481+
double word2_real = std::stod(word2_str.substr(0, word2_str.find(',')));
482+
double word2_imag = std::stod(word2_str.substr(word2_str.find(',') + 1));
483483
if (std::abs(word1_real - word2_real) > test_thr || std::abs(word1_imag - word2_imag) > test_thr)
484484
{
485485
this->failed_check += 1;

source/source_lcao/module_ri/exx_opt_orb.cpp

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -174,22 +174,19 @@ void Exx_Opt_Orb::generate_matrix(
174174
const std::vector<std::vector<RI::Tensor<double>>> ms_abfs_abfs_I = cal_I( ms_abfs_abfs, T,I,T,I );
175175
// < lcaos lcaos | lcaos lcaos > - < lcaos lcaos | abfs > * < abfs | abfs >.I * < abfs | lcaos lcaos >
176176
const RI::Tensor<double> m_lcaoslcaos_lcaoslcaos_proj =
177-
cal_proj_22(
178-
ms_lcaoslcaos_lcaoslcaos.at(T).at(I).at(T).at(I),
177+
ms_lcaoslcaos_lcaoslcaos.at(T).at(I).at(T).at(I) - cal_mul_22(
179178
ms_lcaoslcaos_abfs.at(T).at(I).at(T).at(I),
180179
ms_abfs_abfs_I,
181180
ms_lcaoslcaos_abfs.at(T).at(I).at(T).at(I));
182181
// < lcaos lcaos | jys > - < lcaos lcaos | abfs > * < abfs | abfs >.I * < abfs | jys >
183182
const std::vector<RI::Tensor<double>> m_lcaoslcaos_jys_proj =
184-
{cal_proj_21(
185-
ms_lcaoslcaos_jys.at(T).at(I).at(T).at(I)[0],
183+
{ms_lcaoslcaos_jys.at(T).at(I).at(T).at(I)[0] - cal_mul_21(
186184
ms_lcaoslcaos_abfs.at(T).at(I).at(T).at(I),
187185
ms_abfs_abfs_I,
188186
{ms_jys_abfs.at(T).at(I).at(T).at(I)})};
189187
// < jys | jys > - < jys | abfs > * < abfs | abfs >.I * < abfs | jys >
190188
const std::vector<std::vector<RI::Tensor<double>>> m_jys_jys_proj =
191-
{{cal_proj_11(
192-
ms_jys_jys.at(T).at(I).at(T).at(I),
189+
{{ms_jys_jys.at(T).at(I).at(T).at(I) - cal_mul_11(
193190
{ms_jys_abfs.at(T).at(I).at(T).at(I)},
194191
ms_abfs_abfs_I,
195192
{ms_jys_abfs.at(T).at(I).at(T).at(I)})}};
@@ -228,42 +225,35 @@ void Exx_Opt_Orb::generate_matrix(
228225
const std::vector<std::vector<RI::Tensor<double>>> ms_abfs_abfs_I = cal_I( ms_abfs_abfs, TA,IA,TB,IB );
229226
// < lcaos lcaos | lcaos lcaos > - < lcaos lcaos | abfs > * < abfs | abfs >.I * < abfs | lcaos lcaos >
230227
const RI::Tensor<double> m_lcaoslcaos_lcaoslcaos_proj =
231-
cal_proj_22(
232-
ms_lcaoslcaos_lcaoslcaos.at(TA).at(IA).at(TB).at(IB),
228+
ms_lcaoslcaos_lcaoslcaos.at(TA).at(IA).at(TB).at(IB) - cal_mul_22(
233229
ms_lcaoslcaos_abfs.at(TA).at(IA).at(TB).at(IB),
234230
ms_abfs_abfs_I,
235231
ms_lcaoslcaos_abfs.at(TA).at(IA).at(TB).at(IB));
236232
// < lcaos lcaos | jys > - < lcaos lcaos | abfs > * < abfs | abfs >.I * < abfs | jys >
237233
const std::vector<RI::Tensor<double>> m_lcaoslcaos_jys_proj =
238-
{cal_proj_21(
239-
ms_lcaoslcaos_jys.at(TA).at(IA).at(TB).at(IB)[0],
234+
{ms_lcaoslcaos_jys.at(TA).at(IA).at(TB).at(IB)[0] - cal_mul_21(
240235
ms_lcaoslcaos_abfs.at(TA).at(IA).at(TB).at(IB),
241236
ms_abfs_abfs_I,
242237
{ ms_jys_abfs.at(TA).at(IA).at(TA).at(IA), ms_jys_abfs.at(TA).at(IA).at(TB).at(IB) }),
243-
cal_proj_21(
244-
ms_lcaoslcaos_jys.at(TA).at(IA).at(TB).at(IB)[1],
238+
ms_lcaoslcaos_jys.at(TA).at(IA).at(TB).at(IB)[1] - cal_mul_21(
245239
ms_lcaoslcaos_abfs.at(TA).at(IA).at(TB).at(IB),
246240
ms_abfs_abfs_I,
247241
{ ms_jys_abfs.at(TB).at(IB).at(TA).at(IA), ms_jys_abfs.at(TB).at(IB).at(TB).at(IB) })};
248242
// < jys | jys > - < jys | abfs > * < abfs | abfs >.I * < abfs | jys >
249243
const std::vector<std::vector<RI::Tensor<double>>> m_jys_jys_proj =
250-
{{cal_proj_11(
251-
ms_jys_jys.at(TA).at(IA).at(TA).at(IA),
244+
{{ms_jys_jys.at(TA).at(IA).at(TA).at(IA) - cal_mul_11(
252245
{ ms_jys_abfs.at(TA).at(IA).at(TA).at(IA), ms_jys_abfs.at(TA).at(IA).at(TB).at(IB) },
253246
ms_abfs_abfs_I,
254247
{ ms_jys_abfs.at(TA).at(IA).at(TA).at(IA), ms_jys_abfs.at(TA).at(IA).at(TB).at(IB) }),
255-
cal_proj_11(
256-
ms_jys_jys.at(TA).at(IA).at(TB).at(IB),
248+
ms_jys_jys.at(TA).at(IA).at(TB).at(IB) - cal_mul_11(
257249
{ ms_jys_abfs.at(TA).at(IA).at(TA).at(IA), ms_jys_abfs.at(TA).at(IA).at(TB).at(IB) },
258250
ms_abfs_abfs_I,
259251
{ ms_jys_abfs.at(TB).at(IB).at(TA).at(IA), ms_jys_abfs.at(TB).at(IB).at(TB).at(IB) }) },
260-
{cal_proj_11(
261-
ms_jys_jys.at(TB).at(IB).at(TA).at(IA),
252+
{ms_jys_jys.at(TB).at(IB).at(TA).at(IA) - cal_mul_11(
262253
{ ms_jys_abfs.at(TB).at(IB).at(TA).at(IA), ms_jys_abfs.at(TB).at(IB).at(TB).at(IB) },
263254
ms_abfs_abfs_I,
264255
{ ms_jys_abfs.at(TA).at(IA).at(TA).at(IA), ms_jys_abfs.at(TA).at(IA).at(TB).at(IB) }),
265-
cal_proj_11(
266-
ms_jys_jys.at(TB).at(IB).at(TB).at(IB),
256+
ms_jys_jys.at(TB).at(IB).at(TB).at(IB) - cal_mul_11(
267257
{ ms_jys_abfs.at(TB).at(IB).at(TA).at(IA), ms_jys_abfs.at(TB).at(IB).at(TB).at(IB) },
268258
ms_abfs_abfs_I,
269259
{ ms_jys_abfs.at(TB).at(IB).at(TA).at(IA), ms_jys_abfs.at(TB).at(IB).at(TB).at(IB) }) }};
@@ -301,86 +291,94 @@ void Exx_Opt_Orb::generate_matrix(
301291
}
302292
}
303293

304-
// m_big - m_left * m_middle * m_right.T
305-
RI::Tensor<double> Exx_Opt_Orb::cal_proj_22(
306-
const RI::Tensor<double> & m_big,
294+
// m_left * m_middle * m_right.T
295+
RI::Tensor<double> Exx_Opt_Orb::cal_mul_22(
307296
const std::vector<RI::Tensor<double>> & m_left,
308297
const std::vector<std::vector<RI::Tensor<double>>> & m_middle,
309298
const std::vector<RI::Tensor<double>> & m_right ) const
310299
{
311-
ModuleBase::TITLE("Exx_Opt_Orb::cal_proj_22");
312-
RI::Tensor<double> m_proj = m_big.copy();
300+
ModuleBase::TITLE("Exx_Opt_Orb::cal_mul_22");
301+
RI::Tensor<double> m_mul;
313302
for( size_t il=0; il!=m_left.size(); ++il )
314303
{
315304
for( size_t ir=0; ir!=m_right.size(); ++ir )
316305
{
317-
// m_proj = m_proj - m_left[il] * m_middle[il][ir] * m_right[ir].T;
306+
// m_mul += m_left[il] * m_middle[il][ir] * m_right[ir].T;
318307
const RI::Tensor<double> m_lm = RI::Tensor_Multiply::x0x1y1_x0x1a_ay1(m_left[il], m_middle[il][ir]);
319308
const RI::Tensor<double> m_lmr = RI::Tensor_Multiply::x0x1y0y1_x0x1a_y0y1a(m_lm, m_right[ir]);
320-
m_proj -= m_lmr;
309+
if(m_mul.empty())
310+
{ m_mul = std::move(m_lmr); }
311+
else
312+
{ m_mul += m_lmr; }
321313
}
322314
}
323-
return m_proj;
315+
return m_mul;
324316
}
325-
RI::Tensor<double> Exx_Opt_Orb::cal_proj_21(
326-
const RI::Tensor<double> & m_big,
317+
RI::Tensor<double> Exx_Opt_Orb::cal_mul_21(
327318
const std::vector<RI::Tensor<double>> & m_left,
328319
const std::vector<std::vector<RI::Tensor<double>>> & m_middle,
329320
const std::vector<RI::Tensor<double>> & m_right ) const
330321
{
331-
ModuleBase::TITLE("Exx_Opt_Orb::cal_proj_21");
332-
RI::Tensor<double> m_proj = m_big.copy();
322+
ModuleBase::TITLE("Exx_Opt_Orb::cal_mul_21");
323+
RI::Tensor<double> m_mul;
333324
for( size_t il=0; il!=m_left.size(); ++il )
334325
{
335326
for( size_t ir=0; ir!=m_right.size(); ++ir )
336327
{
337-
// m_proj = m_proj - m_left[il] * m_middle[il][ir] * m_right[ir].T;
328+
// m_mul += m_left[il] * m_middle[il][ir] * m_right[ir].T;
338329
const RI::Tensor<double> m_lm = RI::Tensor_Multiply::x0x1y1_x0x1a_ay1(m_left[il], m_middle[il][ir]);
339330
const RI::Tensor<double> m_lmr = RI::Tensor_Multiply::x0x1y0_x0x1a_y0a(m_lm, m_right[ir]);
340-
m_proj -= m_lmr;
331+
if(m_mul.empty())
332+
{ m_mul = std::move(m_lmr); }
333+
else
334+
{ m_mul += m_lmr; }
341335
}
342336
}
343-
return m_proj;
337+
return m_mul;
344338
}
345-
RI::Tensor<double> Exx_Opt_Orb::cal_proj_12(
346-
const RI::Tensor<double> & m_big,
339+
RI::Tensor<double> Exx_Opt_Orb::cal_mul_12(
347340
const std::vector<RI::Tensor<double>> & m_left,
348341
const std::vector<std::vector<RI::Tensor<double>>> & m_middle,
349342
const std::vector<RI::Tensor<double>> & m_right ) const
350343
{
351-
ModuleBase::TITLE("Exx_Opt_Orb::cal_proj_12");
352-
RI::Tensor<double> m_proj = m_big.copy();
344+
ModuleBase::TITLE("Exx_Opt_Orb::cal_mul_12");
345+
RI::Tensor<double> m_mul;
353346
for( size_t il=0; il!=m_left.size(); ++il )
354347
{
355348
for( size_t ir=0; ir!=m_right.size(); ++ir )
356349
{
357-
// m_proj = m_proj - m_left[il] * m_middle[il][ir] * m_right[ir].T;
350+
// m_mul += m_left[il] * m_middle[il][ir] * m_right[ir].T;
358351
const RI::Tensor<double> m_lm = RI::Tensor_Multiply::x0y1_x0a_ay1(m_left[il], m_middle[il][ir]);
359352
const RI::Tensor<double> m_lmr = RI::Tensor_Multiply::x0y0y1_x0a_y0y1a(m_lm, m_right[ir]);
360-
m_proj -= m_lmr;
353+
if(m_mul.empty())
354+
{ m_mul = std::move(m_lmr); }
355+
else
356+
{ m_mul += m_lmr; }
361357
}
362358
}
363-
return m_proj;
359+
return m_mul;
364360
}
365-
RI::Tensor<double> Exx_Opt_Orb::cal_proj_11(
366-
const RI::Tensor<double> & m_big,
361+
RI::Tensor<double> Exx_Opt_Orb::cal_mul_11(
367362
const std::vector<RI::Tensor<double>> & m_left,
368363
const std::vector<std::vector<RI::Tensor<double>>> & m_middle,
369364
const std::vector<RI::Tensor<double>> & m_right ) const
370365
{
371-
ModuleBase::TITLE("Exx_Opt_Orb::cal_proj_11");
372-
RI::Tensor<double> m_proj = m_big.copy();
366+
ModuleBase::TITLE("Exx_Opt_Orb::cal_mul_11");
367+
RI::Tensor<double> m_mul;
373368
for( size_t il=0; il!=m_left.size(); ++il )
374369
{
375370
for( size_t ir=0; ir!=m_right.size(); ++ir )
376371
{
377-
// m_proj = m_proj - m_left[il] * m_middle[il][ir] * m_right[ir].T;
372+
// m_mul += m_left[il] * m_middle[il][ir] * m_right[ir].T;
378373
const RI::Tensor<double> m_lm = RI::Tensor_Multiply::x0y1_x0a_ay1(m_left[il], m_middle[il][ir]);
379374
const RI::Tensor<double> m_lmr = RI::Tensor_Multiply::x0y0_x0a_y0a(m_lm, m_right[ir]);
380-
m_proj -= m_lmr;
375+
if(m_mul.empty())
376+
{ m_mul = std::move(m_lmr); }
377+
else
378+
{ m_mul += m_lmr; }
381379
}
382380
}
383-
return m_proj;
381+
return m_mul;
384382
}
385383

386384
std::vector<std::vector<RI::Tensor<double>>> Exx_Opt_Orb::cal_I(

source/source_lcao/module_ri/exx_opt_orb.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,19 @@ class Exx_Opt_Orb
2323
std::vector<std::vector<RI::Tensor<double>>> cal_I(
2424
const std::map<size_t,std::map<size_t,std::map<size_t,std::map<size_t,RI::Tensor<double>>>>> &ms,
2525
const size_t TA, const size_t IA, const size_t TB, const size_t IB ) const;
26-
RI::Tensor<double> cal_proj_22(
27-
const RI::Tensor<double> & m_big,
26+
RI::Tensor<double> cal_mul_22(
2827
const std::vector<RI::Tensor<double>> & m_left,
2928
const std::vector<std::vector<RI::Tensor<double>>> & m_middle,
3029
const std::vector<RI::Tensor<double>> & m_right ) const;
31-
RI::Tensor<double> cal_proj_21(
32-
const RI::Tensor<double> & m_big,
30+
RI::Tensor<double> cal_mul_21(
3331
const std::vector<RI::Tensor<double>> & m_left,
3432
const std::vector<std::vector<RI::Tensor<double>>> & m_middle,
3533
const std::vector<RI::Tensor<double>> & m_right ) const;
36-
RI::Tensor<double> cal_proj_12(
37-
const RI::Tensor<double> & m_big,
34+
RI::Tensor<double> cal_mul_12(
3835
const std::vector<RI::Tensor<double>> & m_left,
3936
const std::vector<std::vector<RI::Tensor<double>>> & m_middle,
4037
const std::vector<RI::Tensor<double>> & m_right ) const;
41-
RI::Tensor<double> cal_proj_11(
42-
const RI::Tensor<double> & m_big,
38+
RI::Tensor<double> cal_mul_11(
4339
const std::vector<RI::Tensor<double>> & m_left,
4440
const std::vector<std::vector<RI::Tensor<double>>> & m_middle,
4541
const std::vector<RI::Tensor<double>> & m_right ) const;

source/source_psi/psi.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ const int* Psi<T, Device>::get_ngk_pointer() const
301301
}
302302

303303
template <typename T, typename Device>
304-
const int& Psi<T, Device>::get_psi_bias() const
304+
const size_t& Psi<T, Device>::get_psi_bias() const
305305
{
306306
return this->psi_bias;
307307
}

source/source_psi/psi.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class Psi
130130
const Device* get_device() const;
131131

132132
// return psi_bias
133-
const int& get_psi_bias() const;
133+
const size_t& get_psi_bias() const;
134134

135135
const int& get_current_ngk() const;
136136

@@ -156,7 +156,7 @@ class Psi
156156
// current pointer for getting the psi
157157
mutable T* psi_current = nullptr;
158158
// psi_current = psi + psi_bias;
159-
mutable int psi_bias = 0;
159+
mutable size_t psi_bias = 0;
160160

161161
const int* ngk = nullptr;
162162

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
-0.007602894356
1+
-0.08058091803
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
-0.1377737748
1+
-0.3135463017
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
-0.0003706779237 8.336190151e-06 0.001631126762
2-
9.708686281e-05 1.690692748e-05 -0.005582105144
3-
-0.003585987888 2.035424355e-05 0.001126977947
4-
0.001908170566 0.00362839126 0.001398407512
5-
0.001951408382 -0.003673988622 0.001425592923
1+
-0.001966690223 0.0003863484866 -0.0002775679931
2+
0.0001490387822 0.0001336996716 -0.010866736
3+
-0.009536649342 -0.0001507865891 0.002944084409
4+
0.005665599528 0.00917496669 0.003914146719
5+
0.005688701256 -0.009544228259 0.004286072869

0 commit comments

Comments
 (0)