Skip to content

Commit e667040

Browse files
authored
Refactor&Feature: Package some parameters in DeePKS for simplification & update HR precalc output. (#6706)
* Package some parameters in DeePKS. * Simplify fpre. * Simplify orbpre. * Simplify pdm. * Simplify spre. * Simplify vdpre and vdrpre. * Simplify io. * Simplify force. * clang-format change. * clang-format change. * Change the output of HR precalc. * Update and add new test for DeePKS HR.
1 parent 74e0c5b commit e667040

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+1959
-1659
lines changed

source/source_io/output_mat_sparse.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ void output_mat_sparse(const bool& out_mat_hsR,
2727
//! generate a file containing the Hamiltonian and S(overlap) matrices
2828
if (out_mat_hsR)
2929
{
30-
output_HSR(ucell, istep, v_eff, pv, HS_Arrays, grid, kv, *p_dftu, p_ham);
30+
output_HSR(ucell, istep, pv, HS_Arrays, grid, kv, *p_dftu, p_ham);
3131
}
3232

3333
//! generate a file containing the kinetic energy matrix

source/source_io/write_HS_R.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
template <typename TK>
1616
void ModuleIO::output_HSR(const UnitCell& ucell,
1717
const int& istep,
18-
const ModuleBase::matrix& v_eff,
1918
const Parallel_Orbitals& pv,
2019
LCAO_HS_Arrays& HS_Arrays,
2120
const Grid_Driver& grid, // mohan add 2024-04-06
@@ -304,7 +303,6 @@ void ModuleIO::output_TR(const int istep,
304303
template void ModuleIO::output_HSR<double>(
305304
const UnitCell& ucell,
306305
const int& istep,
307-
const ModuleBase::matrix& v_eff,
308306
const Parallel_Orbitals& pv,
309307
LCAO_HS_Arrays& HS_Arrays,
310308
const Grid_Driver& grid,
@@ -324,7 +322,6 @@ template void ModuleIO::output_HSR<double>(
324322
template void ModuleIO::output_HSR<std::complex<double>>(
325323
const UnitCell& ucell,
326324
const int& istep,
327-
const ModuleBase::matrix& v_eff,
328325
const Parallel_Orbitals& pv,
329326
LCAO_HS_Arrays& HS_Arrays,
330327
const Grid_Driver& grid,

source/source_io/write_HS_R.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ using TAC = std::pair<int, std::array<int, 3>>;
1515
template <typename TK>
1616
void output_HSR(const UnitCell& ucell,
1717
const int& istep,
18-
const ModuleBase::matrix& v_eff,
1918
const Parallel_Orbitals& pv,
2019
LCAO_HS_Arrays& HS_Arrays,
2120
const Grid_Driver& grid, // mohan add 2024-04-06

source/source_lcao/FORCE_gamma.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,10 @@ void Force_LCAO<double>::ftable(const bool isforce,
231231
gd,
232232
*this->ParaV,
233233
nks,
234+
deepks.ld.deepks_param,
234235
kv->kvec_d,
235236
deepks.ld.phialpha,
236237
deepks.ld.gedm,
237-
deepks.ld.inl_index,
238238
fvnl_dalpha,
239239
isstress,
240240
svnl_dalpha);

source/source_lcao/FORCE_k.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,10 @@ void Force_LCAO<std::complex<double>>::ftable(const bool isforce,
257257
gd,
258258
pv,
259259
kv->get_nks(),
260+
deepks.ld.deepks_param,
260261
kv->kvec_d,
261262
deepks.ld.phialpha,
262263
deepks.ld.gedm,
263-
deepks.ld.inl_index,
264264
fvnl_dalpha,
265265
isstress,
266266
svnl_dalpha);

source/source_lcao/LCAO_hamilt.hpp

Lines changed: 39 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#include "source_base/abfs-vector3_order.h"
77
#include "source_base/global_variable.h"
88
#include "source_base/timer.h"
9-
#include "source_lcao/spar_exx.h"
109
#include "source_lcao/module_ri/RI_2D_Comm.h"
10+
#include "source_lcao/spar_exx.h"
1111

1212
#include <RI/global/Global_Func-2.h>
1313
#include <RI/ri/Cell_Nearest.h>
@@ -21,28 +21,27 @@
2121
// Peize Lin add 2022.09.13
2222

2323
template <typename Tdata>
24-
void sparse_format::cal_HR_exx(const UnitCell& ucell,
24+
void sparse_format::cal_HR_exx(
25+
const UnitCell& ucell,
2526
const Parallel_Orbitals& pv,
2627
LCAO_HS_Arrays& HS_Arrays,
2728
const int& current_spin,
2829
const double& sparse_threshold,
2930
const int (&nmp)[3],
30-
const std::vector<std::map<int,
31-
std::map<std::pair<int, std::array<int, 3>>,
32-
RI::Tensor<Tdata>>>>& Hexxs) {
31+
const std::vector<std::map<int, std::map<std::pair<int, std::array<int, 3>>, RI::Tensor<Tdata>>>>& Hexxs)
32+
{
3333
ModuleBase::TITLE("sparse_format", "cal_HR_exx");
3434
ModuleBase::timer::tick("sparse_format", "cal_HR_exx");
3535

3636
const Tdata frac = GlobalC::exx_info.info_global.hybrid_alpha;
3737

3838
std::map<int, std::array<double, 3>> atoms_pos;
39-
for (int iat = 0; iat < ucell.nat; ++iat) {
40-
atoms_pos[iat] = RI_Util::Vector3_to_array3(
41-
ucell.atoms[ucell.iat2it[iat]]
42-
.tau[ucell.iat2ia[iat]]);
39+
for (int iat = 0; iat < ucell.nat; ++iat)
40+
{
41+
atoms_pos[iat] = RI_Util::Vector3_to_array3(ucell.atoms[ucell.iat2it[iat]].tau[ucell.iat2ia[iat]]);
4342
}
4443
const std::array<std::array<double, 3>, 3> latvec
45-
= {RI_Util::Vector3_to_array3(ucell.a1), // too bad to use GlobalC here,
44+
= {RI_Util::Vector3_to_array3(ucell.a1), // too bad to use GlobalC here,
4645
RI_Util::Vector3_to_array3(ucell.a2),
4746
RI_Util::Vector3_to_array3(ucell.a3)};
4847

@@ -51,92 +50,83 @@ void sparse_format::cal_HR_exx(const UnitCell& ucell,
5150
RI::Cell_Nearest<int, int, 3, double, 3> cell_nearest;
5251
cell_nearest.init(atoms_pos, latvec, Rs_period);
5352

54-
const std::vector<int> is_list = (PARAM.inp.nspin != 4)
55-
? std::vector<int>{current_spin}
56-
: std::vector<int>{0, 1, 2, 3};
53+
const std::vector<int> is_list
54+
= (PARAM.inp.nspin != 4) ? std::vector<int>{current_spin} : std::vector<int>{0, 1, 2, 3};
5755

58-
for (const int is: is_list)
56+
for (const int is: is_list)
5957
{
6058
int is0_b = 0;
6159
int is1_b = 0;
6260
std::tie(is0_b, is1_b) = RI_2D_Comm::split_is_block(is);
6361

64-
if (Hexxs.empty())
62+
if (Hexxs.empty())
6563
{
6664
break;
6765
}
6866

69-
for (const auto& HexxA: Hexxs[is])
67+
for (const auto& HexxA: Hexxs[is])
7068
{
7169
const int iat0 = HexxA.first;
72-
for (const auto& HexxB: HexxA.second)
70+
for (const auto& HexxB: HexxA.second)
7371
{
7472
const int iat1 = HexxB.first.first;
7573

7674
const Abfs::Vector3_Order<int> R = RI_Util::array3_to_Vector3(
77-
cell_nearest.get_cell_nearest_discrete(iat0,
78-
iat1,
79-
HexxB.first.second));
75+
cell_nearest.get_cell_nearest_discrete(iat0, iat1, HexxB.first.second));
8076

8177
HS_Arrays.all_R_coor.insert(R);
8278

8379
const RI::Tensor<Tdata>& Hexx = HexxB.second;
8480

85-
for (size_t iw0 = 0; iw0 < Hexx.shape[0]; ++iw0)
81+
for (size_t iw0 = 0; iw0 < Hexx.shape[0]; ++iw0)
8682
{
87-
const int iwt0 = RI_2D_Comm::get_iwt(ucell,iat0, iw0, is0_b);
83+
const int iwt0 = RI_2D_Comm::get_iwt(ucell, iat0, iw0, is0_b);
8884
const int iwt0_local = pv.global2local_row(iwt0);
8985

90-
if (iwt0_local < 0)
86+
if (iwt0_local < 0)
9187
{
9288
continue;
9389
}
9490

95-
for (size_t iw1 = 0; iw1 < Hexx.shape[1]; ++iw1)
91+
for (size_t iw1 = 0; iw1 < Hexx.shape[1]; ++iw1)
9692
{
97-
const int iwt1 = RI_2D_Comm::get_iwt(ucell,iat1, iw1, is1_b);
93+
const int iwt1 = RI_2D_Comm::get_iwt(ucell, iat1, iw1, is1_b);
9894
const int iwt1_local = pv.global2local_col(iwt1);
9995

100-
if (iwt1_local < 0)
96+
if (iwt1_local < 0)
10197
{
10298
continue;
10399
}
104100

105-
if (std::abs(Hexx(iw0, iw1)) > sparse_threshold)
101+
if (std::abs(Hexx(iw0, iw1)) > sparse_threshold)
106102
{
107-
if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2)
103+
if (PARAM.inp.nspin == 1 || PARAM.inp.nspin == 2)
108104
{
109-
auto& HR_sparse_ptr
110-
= HS_Arrays
111-
.HR_sparse[current_spin][R][iwt0];
105+
auto& HR_sparse_ptr = HS_Arrays.HR_sparse[current_spin][R][iwt0];
112106
double& HR_sparse = HR_sparse_ptr[iwt1];
113-
HR_sparse += RI::Global_Func::convert<double>(
114-
frac * Hexx(iw0, iw1));
115-
if (std::abs(HR_sparse) <= sparse_threshold)
107+
HR_sparse += RI::Global_Func::convert<double>(frac * Hexx(iw0, iw1));
108+
if (std::abs(HR_sparse) <= sparse_threshold)
116109
{
117110
HR_sparse_ptr.erase(iwt1);
118111
}
119-
}
120-
else if (PARAM.inp.nspin == 4)
112+
}
113+
else if (PARAM.inp.nspin == 4)
121114
{
122-
auto& HR_sparse_ptr
123-
= HS_Arrays.HR_soc_sparse[R][iwt0];
124-
125-
std::complex<double>& HR_sparse
126-
= HR_sparse_ptr[iwt1];
127-
128-
HR_sparse += RI::Global_Func::convert<
129-
std::complex<double>>(frac * Hexx(iw0, iw1));
130-
131-
if (std::abs(HR_sparse) <= sparse_threshold)
115+
auto& HR_sparse_ptr = HS_Arrays.HR_soc_sparse[R][iwt0];
116+
117+
std::complex<double>& HR_sparse = HR_sparse_ptr[iwt1];
118+
119+
HR_sparse += RI::Global_Func::convert<std::complex<double>>(frac * Hexx(iw0, iw1));
120+
121+
if (std::abs(HR_sparse) <= sparse_threshold)
132122
{
133123
HR_sparse_ptr.erase(iwt1);
134124
}
135-
}
136-
else
125+
}
126+
else
137127
{
138128
throw std::invalid_argument(std::string(__FILE__) + " line "
139-
+ std::to_string(__LINE__));
129+
+ std::to_string(__LINE__));
140130
}
141131
}
142132
}

0 commit comments

Comments
 (0)