Skip to content

Commit 49742a7

Browse files
author
Kai Luo
committed
Refactor: update rdmft solver initialization and pointer usage in LCAO classes
1 parent 453c23d commit 49742a7

File tree

4 files changed

+44
-86
lines changed

4 files changed

+44
-86
lines changed

source/source_esolver/directmin_problems/rdmft_lcao.cpp

Lines changed: 38 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,15 @@
22

33
// #include "source_lcao/module_rdmft/rdmft_tools.h"
44

5-
#include "source_base/timer.h"
6-
#include "source_base/tool_quit.h"
5+
76

87
#include <algorithm>
98
#include <complex>
109
#include <stdexcept>
1110
#include <type_traits>
1211

1312
#ifdef __EXX
14-
#include "../source_lcao/module_ri/exx_opt_orb.h"
13+
#include "source_lcao/module_ri/exx_opt_orb.h"
1514
#endif
1615

1716
using namespace ROPTLITE;
@@ -37,52 +36,42 @@ namespace ModuleESolver
3736
template <typename TK, typename TR>
3837
RDMFT_LCAO<TK, TR>::RDMFT_LCAO(UnitCell& ucell, const Input_para& inp)
3938
{
40-
// this->classname = "RDMFT_LCAO";
41-
// this->basisname = "LCAO";
39+
this->classname = "RDMFT_LCAO";
40+
this->basisname = "LCAO";
41+
this->exx_nao.init(); // mohan add 20251008
4242

4343
this->ucell_ref_ = &ucell;
4444

45-
#ifdef __EXX
46-
// 1. currently this initialization must be put in constructor rather than `before_all_runners()`
47-
// because the latter is not reused by ESolver_LCAO_TDDFT,
48-
// which cause the failure of the subsequent procedure reused by ESolver_LCAO_TDDFT
49-
// 2. always construct but only initialize when if(cal_exx) is true
50-
// because some members like two_level_step are used outside if(cal_exx)
51-
if (GlobalC::exx_info.info_ri.real_number)
52-
{
53-
this->exd = std::make_shared<Exx_LRI_Interface<TK, double>>(GlobalC::exx_info.info_ri);
54-
}
55-
else
56-
{
57-
this->exc = std::make_shared<Exx_LRI_Interface<TK, std::complex<double>>>(GlobalC::exx_info.info_ri);
58-
}
59-
#endif
6045

61-
ESolver_KS_LCAO<TK, TR>::before_all_runners(ucell, inp);
46+
ESolver_KS_LCAO<TK, TR>::before_all_runners(ucell, inp);
47+
48+
// member rdmft_solver called init() in before_all_runners()
49+
50+
rdmft_ptr = dynamic_cast<rdmft::RDMFT<TK, TR>*>(&(this->rdmft_solver));
6251

6352
// initialize rdmft
64-
rdmft_solver.init(this->GG,
65-
this->GK,
66-
this->pv,
67-
ucell,
68-
this->gd,
69-
this->kv,
70-
*(this->pelec),
71-
this->orb_,
72-
this->two_center_bundle_,
73-
inp.dft_functional,
74-
inp.rdmft_power_alpha);
75-
// call update_ion for rdmft_solver;
53+
// this->rdmft_ptr.init(this->GG,
54+
// this->GK,
55+
// this->pv,
56+
// ucell,
57+
// this->gd,
58+
// this->kv,
59+
// *(this->pelec),
60+
// this->orb_,
61+
// this->two_center_bundle_,
62+
// inp.dft_functional,
63+
// inp.rdmft_power_alpha);
64+
// call update_ion for rdmft_ptr;
7665
ESolver_KS_LCAO<TK, TR>::before_scf(ucell, 0);
77-
rdmft_solver.update_ion(ucell, *(this->pw_rho), this->locpp.vloc, this->sf.strucFac);
66+
rdmft_ptr->update_ion(ucell, *(this->pw_rho), this->locpp.vloc, this->sf.strucFac);
7867

79-
rdmft_solver.cal_Energy();
68+
rdmft_ptr->cal_Energy();
8069

81-
std::cout << "Called f() in RDMFT_LCAO, Etotal = " << this->rdmft_solver.Etotal << std::endl;
82-
std::cout << "Called f() in RDMFT_LCAO, E[0] = " << this->rdmft_solver.E_RDMFT[0] << std::endl;
83-
std::cout << "Called f() in RDMFT_LCAO, E[1] = " << this->rdmft_solver.E_RDMFT[1] << std::endl;
84-
std::cout << "Called f() in RDMFT_LCAO, E[2] = " << this->rdmft_solver.E_RDMFT[2] << std::endl;
85-
std::cout << "Called f() in RDMFT_LCAO, E[3] = " << this->rdmft_solver.E_RDMFT[3] << std::endl;
70+
std::cout << "Called f() in RDMFT_LCAO, Etotal = " << this->rdmft_ptr->Etotal << std::endl;
71+
std::cout << "Called f() in RDMFT_LCAO, E[0] = " << this->rdmft_ptr->E_RDMFT[0] << std::endl;
72+
std::cout << "Called f() in RDMFT_LCAO, E[1] = " << this->rdmft_ptr->E_RDMFT[1] << std::endl;
73+
std::cout << "Called f() in RDMFT_LCAO, E[2] = " << this->rdmft_ptr->E_RDMFT[2] << std::endl;
74+
std::cout << "Called f() in RDMFT_LCAO, E[3] = " << this->rdmft_ptr->E_RDMFT[3] << std::endl;
8675

8776
// const integer nbands = static_cast<integer>(inp.nbands > 0 ? inp.nbands : 1);
8877
// const integer stiefel_n = nbands;
@@ -227,19 +216,19 @@ realdp RDMFT_LCAO<TK, TR>::f(const Variable& x) const
227216
std::cout << std::endl;
228217
}
229218

230-
this->rdmft_solver.update_elec(*(this->ucell_ref_), occ_num, wfc);
219+
this->rdmft_ptr->update_elec(*(this->ucell_ref_), occ_num, wfc);
231220

232221
// std::cout << "after update_elec" << std::endl;
233222

234-
this->rdmft_solver.cal_Energy(1); // 1 means not to calculate forces here
223+
this->rdmft_ptr->cal_Energy(1); // 1 means not to calculate forces here
235224

236-
std::cout << "Called f() in RDMFT_LCAO, Etotal = " << this->rdmft_solver.Etotal << std::endl;
237-
std::cout << "Called f() in RDMFT_LCAO, E[0] = " << this->rdmft_solver.E_RDMFT[0] << std::endl;
238-
std::cout << "Called f() in RDMFT_LCAO, E[1] = " << this->rdmft_solver.E_RDMFT[1] << std::endl;
239-
std::cout << "Called f() in RDMFT_LCAO, E[2] = " << this->rdmft_solver.E_RDMFT[2] << std::endl;
240-
std::cout << "Called f() in RDMFT_LCAO, E[3] = " << this->rdmft_solver.E_RDMFT[3] << std::endl;
225+
std::cout << "Called f() in RDMFT_LCAO, Etotal = " << this->rdmft_ptr->Etotal << std::endl;
226+
std::cout << "Called f() in RDMFT_LCAO, E[0] = " << this->rdmft_ptr->E_RDMFT[0] << std::endl;
227+
std::cout << "Called f() in RDMFT_LCAO, E[1] = " << this->rdmft_ptr->E_RDMFT[1] << std::endl;
228+
std::cout << "Called f() in RDMFT_LCAO, E[2] = " << this->rdmft_ptr->E_RDMFT[2] << std::endl;
229+
std::cout << "Called f() in RDMFT_LCAO, E[3] = " << this->rdmft_ptr->E_RDMFT[3] << std::endl;
241230

242-
return this->rdmft_solver.Etotal;
231+
return this->rdmft_ptr->Etotal;
243232
}
244233

245234
template <typename TK, typename TR>
@@ -467,7 +456,7 @@ void RDMFT_LCAO<TK, TR>::KSStateToVariable(const ModuleBase::matrix& occ_num,
467456
throw std::runtime_error("KSStateToVariable requires a valid UnitCell reference");
468457
}
469458

470-
this->rdmft_solver.update_elec(*this->ucell_ref_, occ_num, wfc);
459+
this->rdmft_ptr->update_elec(*this->ucell_ref_, occ_num, wfc);
471460
}
472461

473462
// Explicit template instantiations for the typical scalar combinations used in ABACUS.

source/source_esolver/directmin_problems/rdmft_lcao.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#ifndef RDMFT_LCAO_H
22
#define RDMFT_LCAO_H
33

4+
#include "source_base/timer.h"
5+
#include "source_base/tool_quit.h"
6+
47
// #include "Problems/CStieBrockett.h"
58
#include "Manifolds/Stiefel.h"
69
#include "Manifolds/CStiefel.h"
@@ -77,7 +80,7 @@ class RDMFT_LCAO : public Problem, public ESolver_KS_LCAO<TK, TR>
7780
Stiefel* mani_st = nullptr;
7881
Euclidean* mani_euc = nullptr;
7982

80-
mutable rdmft::RDMFT<TK, TR> rdmft_solver;
83+
rdmft::RDMFT<TK, TR> * rdmft_ptr;
8184
UnitCell* ucell_ref_ = nullptr;
8285

8386
ProductManifold* mani = nullptr;

source/source_esolver/esolver_ks_lcao.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -131,31 +131,13 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(UnitCell& ucell, const Input_pa
131131
// 15) if kpar is not divisible by nks, print a warning
132132
ModuleIO::print_kpar(this->kv.get_nks(), PARAM.globalv.kpar_lcao);
133133

134-
<<<<<<< HEAD
135-
// // 15) initialize rdmft, added by jghan
136-
// if (inp.rdmft == true)
137-
// {
138-
// rdmft_solver.init(this->GG,
139-
// this->GK,
140-
// this->pv,
141-
// ucell,
142-
// this->gd,
143-
// this->kv,
144-
// *(this->pelec),
145-
// this->orb_,
146-
// two_center_bundle_,
147-
// inp.dft_functional,
148-
// inp.rdmft_power_alpha);
149-
// }
150-
=======
151134
// 16) init rdmft, added by jghan
152135
if (inp.rdmft == true)
153136
{
154137
rdmft_solver.init(this->pv, ucell,
155138
this->gd, this->kv, *(this->pelec), this->orb_,
156139
two_center_bundle_, inp.dft_functional, inp.rdmft_power_alpha);
157140
}
158-
>>>>>>> upstream/develop
159141

160142
ModuleBase::timer::tick("ESolver_KS_LCAO", "before_all_runners");
161143
return;

source/source_esolver/esolver_rdmft_lcao.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
// for xc_functional
2727
#include "source_hamilt/module_xc/xc_functional.h"
2828
// for pw_setup
29-
#include "source_esolver/pw_setup.h"
29+
// #include "source_esolver/pw_setup.h"
3030
// for charge_mixing
3131
#include "source_estate/module_charge/charge_mixing.h"
3232
//for setup_parameter
@@ -38,7 +38,7 @@
3838
// for LCAO_domain
3939
#include "source_lcao/LCAO_domain.h"
4040
// for read_wfc_nao
41-
#include "source_io/read_wfc_lcao.h"
41+
// #include "source_io/read_wfc_lcao.h"
4242
// for fixed_weights
4343
#include "source_estate/elecstate_tools.h"
4444
// for Force_Stress_LCAO
@@ -68,22 +68,6 @@ ESolver_RDMFT_LCAO<TK, TR>::ESolver_RDMFT_LCAO()
6868

6969
// Initialize variable pointers
7070
// this->params = nullptr;
71-
72-
#ifdef __EXX
73-
// 1. currently this initialization must be put in constructor rather than `before_all_runners()`
74-
// because the latter is not reused by ESolver_LCAO_TDDFT,
75-
// which cause the failure of the subsequent procedure reused by ESolver_LCAO_TDDFT
76-
// 2. always construct but only initialize when if(cal_exx) is true
77-
// because some members like two_level_step are used outside if(cal_exx)
78-
if (GlobalC::exx_info.info_ri.real_number)
79-
{
80-
this->exd = std::make_shared<Exx_LRI_Interface<double, double>>(GlobalC::exx_info.info_ri);
81-
}
82-
else
83-
{
84-
this->exc = std::make_shared<Exx_LRI_Interface<double, std::complex<double>>>(GlobalC::exx_info.info_ri);
85-
}
86-
#endif
8771
}
8872

8973
template <typename TK, typename TR>

0 commit comments

Comments
 (0)