Skip to content

Commit da1baf5

Browse files
authored
Refactor: make get_S a new esolver (#5515)
* Refactor: make get_S a new esolver * add head files * add esolver_gets
1 parent f68dc9f commit da1baf5

File tree

10 files changed

+273
-191
lines changed

10 files changed

+273
-191
lines changed

source/Makefile.Objects

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ OBJS_ESOLVER_LCAO=esolver_ks_lcao.o\
255255
dpks_cal_e_delta_band.o\
256256
set_matrix_grid.o\
257257
lcao_before_scf.o\
258-
lcao_gets.o\
258+
esolver_gets.o\
259259
lcao_others.o\
260260
lcao_init_after_vc.o\
261261

source/driver_run.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,16 @@ void Driver::driver_run() {
6767
Relax_Driver rl_driver;
6868
rl_driver.relax_driver(p_esolver);
6969
}
70+
else if (cal_type == "get_S")
71+
{
72+
p_esolver->runner(0, GlobalC::ucell);
73+
}
7074
else
7175
{
7276
//! supported "other" functions:
7377
//! get_pchg(LCAO),
7478
//! test_memory(PW,LCAO),
7579
//! test_neighbour(LCAO),
76-
//! get_S(LCAO),
7780
//! gen_bessel(PW), et al.
7881
const int istep = 0;
7982
p_esolver->others(istep);

source/module_esolver/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ if(ENABLE_LCAO)
2121
dpks_cal_e_delta_band.cpp
2222
set_matrix_grid.cpp
2323
lcao_before_scf.cpp
24-
lcao_gets.cpp
24+
esolver_gets.cpp
2525
lcao_others.cpp
2626
lcao_init_after_vc.cpp
2727
)

source/module_esolver/esolver.cpp

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
#include "module_base/module_device/device.h"
66
#include "module_parameter/parameter.h"
77
#ifdef __LCAO
8-
#include "esolver_ks_lcaopw.h"
8+
#include "esolver_gets.h"
99
#include "esolver_ks_lcao.h"
1010
#include "esolver_ks_lcao_tddft.h"
11+
#include "esolver_ks_lcaopw.h"
1112
#include "module_lr/esolver_lrtd_lcao.h"
1213
extern "C"
1314
{
@@ -188,18 +189,39 @@ ESolver* init_esolver(const Input_para& inp, UnitCell& ucell)
188189
{
189190
if (PARAM.globalv.gamma_only_local)
190191
{
191-
return new ESolver_KS_LCAO<double, double>();
192-
}
193-
else if (PARAM.inp.nspin < 4)
194-
{
195-
return new ESolver_KS_LCAO<std::complex<double>, double>();
196-
}
197-
else
198-
{
199-
return new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
200-
}
201-
}
202-
else if (esolver_type == "ksdft_lcao_tddft")
192+
if (PARAM.inp.calculation == "get_S")
193+
{
194+
return new ESolver_GetS<double, double>();
195+
}
196+
else
197+
{
198+
return new ESolver_KS_LCAO<double, double>();
199+
}
200+
}
201+
else if (PARAM.inp.nspin < 4)
202+
{
203+
if (PARAM.inp.calculation == "get_S")
204+
{
205+
return new ESolver_GetS<std::complex<double>, double>();
206+
}
207+
else
208+
{
209+
return new ESolver_KS_LCAO<std::complex<double>, double>();
210+
}
211+
}
212+
else
213+
{
214+
if (PARAM.inp.calculation == "get_S")
215+
{
216+
return new ESolver_GetS<std::complex<double>, std::complex<double>>();
217+
}
218+
else
219+
{
220+
return new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
221+
}
222+
}
223+
}
224+
else if (esolver_type == "ksdft_lcao_tddft")
203225
{
204226
return new ESolver_KS_LCAO_TDDFT();
205227
}
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
#include "esolver_gets.h"
2+
3+
#include "module_base/timer.h"
4+
#include "module_cell/module_neighbor/sltk_atom_arrange.h"
5+
#include "module_elecstate/elecstate_lcao.h"
6+
#include "module_hamilt_lcao/hamilt_lcaodft/LCAO_domain.h"
7+
#include "module_hamilt_lcao/hamilt_lcaodft/hamilt_lcao.h"
8+
#include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/operator_lcao.h"
9+
#include "module_io/print_info.h"
10+
#include "module_io/write_HS_R.h"
11+
12+
namespace ModuleESolver
13+
{
14+
15+
template <typename TK, typename TR>
16+
ESolver_GetS<TK, TR>::ESolver_GetS()
17+
{
18+
this->classname = "ESolver_GetS";
19+
this->basisname = "LCAO";
20+
}
21+
22+
template <typename TK, typename TR>
23+
ESolver_GetS<TK, TR>::~ESolver_GetS()
24+
{
25+
}
26+
27+
template <typename TK, typename TR>
28+
void ESolver_GetS<TK, TR>::before_all_runners(const Input_para& inp, UnitCell& ucell)
29+
{
30+
ModuleBase::TITLE("ESolver_GetS", "before_all_runners");
31+
ModuleBase::timer::tick("ESolver_GetS", "before_all_runners");
32+
33+
// 1.1) read pseudopotentials
34+
ucell.read_pseudo(GlobalV::ofs_running);
35+
36+
// 1.2) symmetrize things
37+
if (ModuleSymmetry::Symmetry::symm_flag == 1)
38+
{
39+
ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running);
40+
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SYMMETRY");
41+
}
42+
43+
// 1.3) Setup k-points according to symmetry.
44+
this->kv.set(ucell.symm, inp.kpoint_file, inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
45+
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");
46+
47+
ModuleIO::setup_parameters(ucell, this->kv);
48+
49+
// 2) init ElecState
50+
// autoset nbands in ElecState, it should before basis_init (for Psi 2d division)
51+
if (this->pelec == nullptr)
52+
{
53+
// TK stands for double and complex<double>?
54+
this->pelec = new elecstate::ElecStateLCAO<TK>(&(this->chr), // use which parameter?
55+
&(this->kv),
56+
this->kv.get_nks(),
57+
&(this->GG), // mohan add 2024-04-01
58+
&(this->GK), // mohan add 2024-04-01
59+
this->pw_rho,
60+
this->pw_big);
61+
}
62+
63+
// 3) init LCAO basis
64+
// reading the localized orbitals/projectors
65+
// construct the interpolation tables.
66+
LCAO_domain::init_basis_lcao(this->pv,
67+
inp.onsite_radius,
68+
inp.lcao_ecut,
69+
inp.lcao_dk,
70+
inp.lcao_dr,
71+
inp.lcao_rmax,
72+
ucell,
73+
two_center_bundle_,
74+
orb_);
75+
76+
// 4) initialize the density matrix
77+
// DensityMatrix is allocated here, DMK is also initialized here
78+
// DMR is not initialized here, it will be constructed in each before_scf
79+
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, &(this->pv), inp.nspin);
80+
81+
ModuleBase::timer::tick("ESolver_GetS", "before_all_runners");
82+
}
83+
84+
template <>
85+
void ESolver_GetS<double, double>::runner(const int istep, UnitCell& ucell)
86+
{
87+
ModuleBase::TITLE("ESolver_GetS", "runner");
88+
ModuleBase::WARNING_QUIT("ESolver_GetS<double, double>::runner", "not implemented");
89+
}
90+
91+
template <>
92+
void ESolver_GetS<std::complex<double>, std::complex<double>>::runner(const int istep, UnitCell& ucell)
93+
{
94+
ModuleBase::TITLE("ESolver_GetS", "runner");
95+
ModuleBase::timer::tick("ESolver_GetS", "runner");
96+
97+
// (1) Find adjacent atoms for each atom.
98+
double search_radius = -1.0;
99+
search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running,
100+
PARAM.inp.out_level,
101+
orb_.get_rcutmax_Phi(),
102+
ucell.infoNL.get_rcutmax_Beta(),
103+
PARAM.globalv.gamma_only_local);
104+
105+
atom_arrange::search(PARAM.inp.search_pbc,
106+
GlobalV::ofs_running,
107+
GlobalC::GridD,
108+
ucell,
109+
search_radius,
110+
PARAM.inp.test_atom_input);
111+
112+
this->RA.for_2d(this->pv, PARAM.globalv.gamma_only_local, orb_.cutoffs());
113+
114+
if (this->p_hamilt == nullptr)
115+
{
116+
this->p_hamilt
117+
= new hamilt::HamiltLCAO<std::complex<double>, std::complex<double>>(&this->pv,
118+
this->kv,
119+
*(two_center_bundle_.overlap_orb),
120+
orb_.cutoffs());
121+
dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, std::complex<double>>*>(this->p_hamilt->ops)
122+
->contributeHR();
123+
}
124+
125+
const std::string fn = PARAM.globalv.global_out_dir + "SR.csr";
126+
std::cout << " The file is saved in " << fn << std::endl;
127+
ModuleIO::output_SR(pv, GlobalC::GridD, this->p_hamilt, fn);
128+
129+
ModuleBase::timer::tick("ESolver_GetS", "runner");
130+
}
131+
132+
template <>
133+
void ESolver_GetS<std::complex<double>, double>::runner(const int istep, UnitCell& ucell)
134+
{
135+
ModuleBase::TITLE("ESolver_GetS", "runner");
136+
ModuleBase::timer::tick("ESolver_GetS", "runner");
137+
138+
// (1) Find adjacent atoms for each atom.
139+
double search_radius = -1.0;
140+
search_radius = atom_arrange::set_sr_NL(GlobalV::ofs_running,
141+
PARAM.inp.out_level,
142+
orb_.get_rcutmax_Phi(),
143+
ucell.infoNL.get_rcutmax_Beta(),
144+
PARAM.globalv.gamma_only_local);
145+
146+
atom_arrange::search(PARAM.inp.search_pbc,
147+
GlobalV::ofs_running,
148+
GlobalC::GridD,
149+
ucell,
150+
search_radius,
151+
PARAM.inp.test_atom_input);
152+
153+
this->RA.for_2d(this->pv, PARAM.globalv.gamma_only_local, orb_.cutoffs());
154+
155+
if (this->p_hamilt == nullptr)
156+
{
157+
this->p_hamilt = new hamilt::HamiltLCAO<std::complex<double>, double>(&this->pv,
158+
this->kv,
159+
*(two_center_bundle_.overlap_orb),
160+
orb_.cutoffs());
161+
dynamic_cast<hamilt::OperatorLCAO<std::complex<double>, double>*>(this->p_hamilt->ops)->contributeHR();
162+
}
163+
164+
const std::string fn = PARAM.globalv.global_out_dir + "SR.csr";
165+
std::cout << " The file is saved in " << fn << std::endl;
166+
ModuleIO::output_SR(pv, GlobalC::GridD, this->p_hamilt, fn);
167+
168+
ModuleBase::timer::tick("ESolver_GetS", "runner");
169+
}
170+
171+
template class ESolver_GetS<double, double>;
172+
template class ESolver_GetS<std::complex<double>, double>;
173+
template class ESolver_GetS<std::complex<double>, std::complex<double>>;
174+
175+
} // namespace ModuleESolver
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifndef ESOLVER_GETS_H
2+
#define ESOLVER_GETS_H
3+
4+
#include "module_basis/module_nao/two_center_bundle.h"
5+
#include "module_cell/unitcell.h"
6+
#include "module_esolver/esolver_ks.h"
7+
#include "module_hamilt_lcao/module_gint/gint_gamma.h"
8+
#include "module_hamilt_lcao/module_gint/gint_k.h"
9+
10+
#include <memory>
11+
12+
namespace ModuleESolver
13+
{
14+
template <typename TK, typename TR>
15+
class ESolver_GetS : public ESolver_KS<TK>
16+
{
17+
public:
18+
ESolver_GetS();
19+
~ESolver_GetS();
20+
21+
void before_all_runners(const Input_para& inp, UnitCell& ucell) override;
22+
23+
void after_all_runners() {};
24+
25+
void runner(const int istep, UnitCell& ucell) override;
26+
27+
//! calculate total energy of a given system
28+
double cal_energy() {};
29+
30+
//! calcualte forces for the atoms in the given cell
31+
void cal_force(ModuleBase::matrix& force) {};
32+
33+
//! calcualte stress of given cell
34+
void cal_stress(ModuleBase::matrix& stress) {};
35+
36+
protected:
37+
// we will get rid of this class soon, don't use it, mohan 2024-03-28
38+
Record_adj RA;
39+
40+
// 2d block - cyclic distribution info
41+
Parallel_Orbitals pv;
42+
43+
// used for k-dependent grid integration.
44+
Gint_k GK;
45+
46+
// used for gamma only algorithms.
47+
Gint_Gamma GG;
48+
49+
TwoCenterBundle two_center_bundle_;
50+
51+
// // temporary introduced during removing GlobalC::ORB
52+
LCAO_Orbitals orb_;
53+
};
54+
} // namespace ModuleESolver
55+
#endif

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -122,30 +122,7 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
122122
ModuleBase::TITLE("ESolver_KS_LCAO", "before_all_runners");
123123
ModuleBase::timer::tick("ESolver_KS_LCAO", "before_all_runners");
124124

125-
// 1) calculate overlap matrix S
126-
if (PARAM.inp.calculation == "get_S")
127-
{
128-
// 1.1) read pseudopotentials
129-
ucell.read_pseudo(GlobalV::ofs_running);
130-
131-
// 1.2) symmetrize things
132-
if (ModuleSymmetry::Symmetry::symm_flag == 1)
133-
{
134-
ucell.symm.analy_sys(ucell.lat, ucell.st, ucell.atoms, GlobalV::ofs_running);
135-
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "SYMMETRY");
136-
}
137-
138-
// 1.3) Setup k-points according to symmetry.
139-
this->kv.set(ucell.symm, PARAM.inp.kpoint_file, PARAM.inp.nspin, ucell.G, ucell.latvec, GlobalV::ofs_running);
140-
ModuleBase::GlobalFunc::DONE(GlobalV::ofs_running, "INIT K-POINTS");
141-
142-
ModuleIO::setup_parameters(ucell, this->kv);
143-
}
144-
else
145-
{
146-
// 1) else, call before_all_runners() in ESolver_KS
147-
ESolver_KS<TK>::before_all_runners(inp, ucell);
148-
} // end ifnot get_S
125+
ESolver_KS<TK>::before_all_runners(inp, ucell);
149126

150127
// 2) init ElecState
151128
// autoset nbands in ElecState, it should before basis_init (for Psi 2d division)
@@ -179,13 +156,6 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
179156
// DMR is not initialized here, it will be constructed in each before_scf
180157
dynamic_cast<elecstate::ElecStateLCAO<TK>*>(this->pelec)->init_DM(&this->kv, &(this->pv), PARAM.inp.nspin);
181158

182-
// this function should be removed outside of the function in near future
183-
if (PARAM.inp.calculation == "get_S")
184-
{
185-
ModuleBase::timer::tick("ESolver_KS_LCAO", "init");
186-
return;
187-
}
188-
189159
// 5) initialize Hamilt in LCAO
190160
// * allocate H and S matrices according to computational resources
191161
// * set the 'trace' between local H/S and global H/S

source/module_esolver/esolver_ks_lcao.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ class ESolver_KS_LCAO : public ESolver_KS<TK> {
3939

4040
void after_all_runners() override;
4141

42-
void get_S();
43-
4442
protected:
4543
virtual void before_scf(const int istep) override;
4644

0 commit comments

Comments
 (0)