Skip to content

Commit 49b5005

Browse files
linpeizePeizeLin
authored andcommitted
Refactor: move Exx_Abfs::Abfs_Index to Element_Basis_Index (deepmodeling#6622)
* Refactor: move exx_abfs-abfs_index to element_basis_index * Refactor: move Element_Basis_Index::construct_range() to module_ao --------- Co-authored-by: linpz <[email protected]>
1 parent 7265f4a commit 49b5005

File tree

15 files changed

+131
-95
lines changed

15 files changed

+131
-95
lines changed

source/source_base/element_basis_index.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,26 @@
44
//==========================================================
55

66
#include "element_basis_index.h"
7+
78
namespace ModuleBase
89
{
910

10-
Element_Basis_Index::IndexLNM Element_Basis_Index::construct_index( const Range &range )
11+
Element_Basis_Index::IndexLNM
12+
Element_Basis_Index::construct_index( const Range &range )
1113
{
1214
IndexLNM index;
1315
index.resize( range.size() );
14-
for( size_t T=0; T!=range.size(); ++T )
16+
for( std::size_t T=0; T!=range.size(); ++T )
1517
{
16-
size_t count=0;
18+
std::size_t count=0;
1719
index[T].resize( range[T].size() );
18-
for( size_t L=0; L!=range[T].size(); ++L )
20+
for( std::size_t L=0; L!=range[T].size(); ++L )
1921
{
2022
index[T][L].resize( range[T][L].N );
21-
for( size_t N=0; N!=range[T][L].N; ++N )
23+
for( std::size_t N=0; N!=range[T][L].N; ++N )
2224
{
2325
index[T][L][N].resize( range[T][L].M );
24-
for( size_t M=0; M!=range[T][L].M; ++M )
26+
for( std::size_t M=0; M!=range[T][L].M; ++M )
2527
{
2628
index[T][L][N][M] = count;
2729
++count;

source/source_base/element_basis_index.h

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,40 +8,41 @@
88

99
#include <cstddef>
1010
#include <vector>
11+
1112
namespace ModuleBase
1213
{
1314

14-
class Element_Basis_Index
15+
namespace Element_Basis_Index
1516
{
16-
private:
17-
17+
//private:
18+
1819
struct NM
1920
{
2021
public:
21-
size_t N;
22-
size_t M;
22+
std::size_t N;
23+
std::size_t M;
2324
};
24-
25-
class Index_TL: public std::vector<std::vector<size_t>>
25+
26+
class Index_TL: public std::vector<std::vector<std::size_t>>
2627
{
2728
public:
28-
size_t N;
29-
size_t M;
29+
std::size_t N;
30+
std::size_t M;
3031
};
31-
32+
3233
class Index_T: public std::vector<Index_TL>
3334
{
3435
public:
35-
size_t count_size;
36-
};
37-
38-
public:
39-
40-
typedef std::vector<std::vector<NM>> Range; // range[T][L]
36+
std::size_t count_size;
37+
};
38+
39+
//public:
40+
41+
typedef std::vector<std::vector<NM>> Range; // range[T][L]
4142
typedef std::vector<Index_T> IndexLNM; // index[T][L][N][M]
42-
43-
static IndexLNM construct_index( const Range &range );
44-
};
43+
44+
extern IndexLNM construct_index( const Range &range );
45+
}
4546

4647
}
4748

source/source_basis/module_ao/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ if(ENABLE_LCAO)
99
ORB_nonlocal_lm.cpp
1010
ORB_read.cpp
1111
parallel_orbitals.cpp
12+
element_basis_index-ORB.cpp
1213
)
1314

1415
if(ENABLE_COVERAGE)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#include "element_basis_index-ORB.h"
2+
3+
#include "ORB_read.h"
4+
#include "ORB_atomic_lm.h"
5+
6+
namespace ModuleBase
7+
{
8+
9+
ModuleBase::Element_Basis_Index::Range
10+
Element_Basis_Index::construct_range( const LCAO_Orbitals &orb )
11+
{
12+
ModuleBase::Element_Basis_Index::Range range;
13+
range.resize( orb.get_ntype() );
14+
for( std::size_t T=0; T!=range.size(); ++T )
15+
{
16+
range[T].resize( orb.Phi[T].getLmax()+1 );
17+
for( std::size_t L=0; L!=range[T].size(); ++L )
18+
{
19+
range[T][L].N = orb.Phi[T].getNchi(L);
20+
range[T][L].M = 2*L+1;
21+
}
22+
}
23+
return range;
24+
}
25+
26+
27+
ModuleBase::Element_Basis_Index::Range
28+
Element_Basis_Index::construct_range( const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &orb )
29+
{
30+
ModuleBase::Element_Basis_Index::Range range;
31+
range.resize( orb.size() );
32+
for( std::size_t T=0; T!=range.size(); ++T )
33+
{
34+
range[T].resize( orb[T].size() );
35+
for( std::size_t L=0; L!=range[T].size(); ++L )
36+
{
37+
range[T][L].N = orb[T][L].size();
38+
range[T][L].M = 2*L+1;
39+
}
40+
}
41+
return range;
42+
}
43+
44+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef ELEMENT_BASIS_INDEX_ORB_H
2+
#define ELEMENT_BASIS_INDEX_ORB_H
3+
4+
#include "../../source_base/element_basis_index.h"
5+
#include <vector>
6+
7+
class Numerical_Orbital_Lm;
8+
class LCAO_Orbitals;
9+
10+
namespace ModuleBase
11+
{
12+
13+
namespace Element_Basis_Index
14+
{
15+
extern Range construct_range( const LCAO_Orbitals &orb );
16+
17+
extern Range construct_range( const std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> &orb ); // orb[T][L][N]
18+
}
19+
20+
}
21+
22+
#endif

source/source_esolver/esolver_ks_lcao.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,24 @@ class ESolver_KS_LCAO : public ESolver_KS<TK>
108108
int get_nk() const { return this->kv.get_nks(); } // if pv.nk exists and is public
109109
int get_ncol() const { return pv.ncol; }
110110
int get_nrow() const { return pv.nrow; }
111+
112+
113+
public:
114+
const Record_adj & get_RA() const { return RA; }
115+
const Grid_Driver & get_gd() const { return gd; }
116+
const Parallel_Orbitals & get_pv() const { return pv; }
117+
const Gint_k & get_GK() const { return GK; }
118+
const Gint_Gamma & get_GG() const { return GG; }
119+
const Grid_Technique & get_GridT() const { return GridT; }
120+
#ifndef __OLD_GINT
121+
const std::unique_ptr<ModuleGint::GintInfo> & get_gint_info() const { return gint_info_; }
122+
#endif
123+
const TwoCenterBundle & get_two_center_bundle() const { return two_center_bundle_; }
124+
const rdmft::RDMFT<TK, TR> & get_rdmft_solver() const { return rdmft_solver; }
125+
const LCAO_Orbitals & get_orb() const { return orb_; }
126+
const ModuleBase::matrix & get_scs() const { return scs; }
127+
const Setup_DeePKS<TK> & get_deepks() const { return deepks; }
128+
const Exx_NAO<TK> & get_exx_nao() const { return exx_nao; }
111129
};
112130
} // namespace ModuleESolver
113131
#endif

source/source_lcao/hamilt_lcao.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,20 @@ class HamiltLCAO : public Hamilt<TK>
105105
{
106106
return this->hR;
107107
}
108+
const HContainer<TR>* getHR() const
109+
{
110+
return this->hR;
111+
}
108112

109113
/// get SR pointer of *this->sR, which is a HContainer<TR> and contains S(R)
110114
HContainer<TR>*& getSR()
111115
{
112116
return this->sR;
113117
}
118+
const HContainer<TR>* getSR() const
119+
{
120+
return this->sR;
121+
}
114122

115123
#ifdef __MLALGO
116124
/// get V_delta_R pointer of *this->V_delta_R, which is a HContainer<TR> and contains V_delta(R)

source/source_lcao/module_ri/ABFs_Construct-PCA.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#include "ABFs_Construct-PCA.h"
22

3-
#include "exx_abfs-abfs_index.h"
43
#include "../../source_base/module_external/lapack_connector.h"
54
#include "../../source_base/global_function.h"
6-
#include "../../source_base/element_basis_index.h"
5+
#include "../../source_basis/module_ao/element_basis_index-ORB.h"
76
#include "../../source_base/matrix.h"
87
#include "../../source_lcao/module_ri/Matrix_Orbs11.h"
98
#include "../../source_lcao/module_ri/Matrix_Orbs21.h"
@@ -81,12 +80,12 @@ namespace PCA
8180
ModuleBase::TITLE("ABFs_Construct::PCA::cal_PCA");
8281

8382
const ModuleBase::Element_Basis_Index::Range
84-
range_lcaos = Exx_Abfs::Abfs_Index::construct_range( lcaos );
83+
range_lcaos = ModuleBase::Element_Basis_Index::construct_range( lcaos );
8584
const ModuleBase::Element_Basis_Index::IndexLNM
8685
index_lcaos = ModuleBase::Element_Basis_Index::construct_index( range_lcaos );
8786

8887
const ModuleBase::Element_Basis_Index::Range
89-
range_abfs = Exx_Abfs::Abfs_Index::construct_range( abfs );
88+
range_abfs = ModuleBase::Element_Basis_Index::construct_range( abfs );
9089
const ModuleBase::Element_Basis_Index::IndexLNM
9190
index_abfs = ModuleBase::Element_Basis_Index::construct_index( range_abfs );
9291

source/source_lcao/module_ri/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ if (ENABLE_LIBRI)
1414
if(ENABLE_LCAO)
1515
list(APPEND objects
1616
conv_coulomb_pot_k.cpp
17-
exx_abfs-abfs_index.cpp
1817
exx_abfs-construct_orbs.cpp
1918
exx_abfs-io.cpp
2019
exx_abfs-jle.cpp

source/source_lcao/module_ri/LRI_CV.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
#include "LRI_CV.h"
1010
#include "LRI_CV_Tools.h"
11-
#include "exx_abfs-abfs_index.h"
1211
#include "exx_abfs-construct_orbs.h"
1312
#include "RI_Util.h"
13+
#include "../../source_basis/module_ao/element_basis_index-ORB.h"
1414
#include "../../source_base/tool_title.h"
1515
#include "../../source_base/timer.h"
1616
#include "../../source_pw/module_pwdft/global.h"
@@ -62,11 +62,11 @@ void LRI_CV<Tdata>::set_orbitals(
6262
= Exx_Abfs::Construct_Orbs::get_Rmax(this->abfs_ccp);
6363

6464
const ModuleBase::Element_Basis_Index::Range
65-
range_lcaos = Exx_Abfs::Abfs_Index::construct_range( lcaos );
65+
range_lcaos = ModuleBase::Element_Basis_Index::construct_range( lcaos );
6666
this->index_lcaos = ModuleBase::Element_Basis_Index::construct_index( range_lcaos );
6767

6868
const ModuleBase::Element_Basis_Index::Range
69-
range_abfs = Exx_Abfs::Abfs_Index::construct_range( abfs );
69+
range_abfs = ModuleBase::Element_Basis_Index::construct_range( abfs );
7070
this->index_abfs = ModuleBase::Element_Basis_Index::construct_index( range_abfs );
7171

7272
int Lmax_v = std::numeric_limits<double>::min();

0 commit comments

Comments
 (0)