Skip to content

Commit 3ac2041

Browse files
author
Kai Luo
committed
Refactor RDMFT_LCAO and RDMFT classes: Clean up constructors, enhance initialization, and update energy calculation methods
1 parent a5f9852 commit 3ac2041

File tree

4 files changed

+102
-28
lines changed

4 files changed

+102
-28
lines changed

source/source_esolver/directmin_problems/rdmft_lcao.cpp

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,67 @@ RDMFT_LCAO<TK, TR>::RDMFT_LCAO(UnitCell& ucell, const Input_para& inp)
4040
}
4141
#endif
4242

43-
// ESolver_KS_LCAO<TK, TR>::before_all_runners(ucell, inp);
43+
ESolver_KS_LCAO<TK, TR>::before_all_runners(ucell, inp);
4444

4545
// initialize rdmft
46+
rdmft_solver.init(this->GG,
47+
this->GK,
48+
this->pv,
49+
ucell,
50+
this->gd,
51+
this->kv,
52+
*(this->pelec),
53+
this->orb_,
54+
this->two_center_bundle_,
55+
inp.dft_functional,
56+
inp.rdmft_power_alpha);
57+
4658

4759

60+
// const integer nbands = static_cast<integer>(inp.nbands > 0 ? inp.nbands : 1);
61+
// const integer stiefel_n = nbands;
62+
// const integer stiefel_p = nbands;
63+
64+
nbands_total = inp.nbands;
65+
nspin = inp.nspin;
66+
nk_total = this->kv.get_nks();
67+
nk_total *= nspin;
68+
69+
nbasis_total = this->psi ->get_nbasis();
70+
71+
// next set up the manifold
72+
integer stiefel_n = nbasis_total;
73+
integer stiefel_p = nbands_total;
74+
75+
integer multitude = nk_total;
76+
77+
integer numoftypes = 2; // wavefunctions and occupation numbers
78+
79+
80+
this->mani_euc = new Euclidean(nbands_total, "real");
81+
82+
// create the product manifold
83+
84+
if(inp.gamma_only)
85+
{
86+
this->mani_st = new Stiefel(stiefel_n, stiefel_p);
87+
this->mani_st->ChooseParamsSet1(); // change later
88+
this->mani = new ProductManifold(numoftypes, this->mani_st, multitude, this->mani_euc, multitude);
89+
}
90+
else
91+
{
92+
this->mani_cst = new CStiefel(stiefel_n, stiefel_p);
93+
this->mani_cst->ChooseParamsSet1(); // change later
94+
this->mani = new ProductManifold(numoftypes, this->mani_cst, multitude, this->mani_euc, multitude);
95+
}
96+
97+
// X = this->mani->RandominManifold(); // need to update it, either from KS calculation or file
4898

49-
const integer nbands = static_cast<integer>(inp.nbands > 0 ? inp.nbands : 1);
50-
const integer stiefel_n = nbands;
51-
const integer stiefel_p = nbands;
5299

53-
this->mani = new CStiefel(stiefel_n, stiefel_p);
54100
this->SetDomain(this->mani);
55-
this->mani->ChooseParamsSet4();
56-
this->mani->CheckParams();
101+
// this->mani->CheckParams();
102+
103+
// this->mani->ChooseParamsSet4();
57104
}
58105

59106
// template <typename TK, typename TR>
@@ -75,21 +122,20 @@ RDMFT_LCAO<TK, TR>::RDMFT_LCAO(UnitCell& ucell, const Input_para& inp)
75122
template <typename TK, typename TR>
76123
RDMFT_LCAO<TK, TR>::~RDMFT_LCAO()
77124
{
78-
if (this->mani != nullptr)
79-
{
80-
delete this->mani;
81-
this->mani = nullptr;
82-
}
125+
// if (this->mani != nullptr)
126+
// {
127+
// delete this->mani;
128+
// this->mani = nullptr;
129+
// }
83130
}
84131

85132

86133

87134
template <typename TK, typename TR>
88135
realdp RDMFT_LCAO<TK, TR>::f(const Variable& x) const
89136
{
90-
// Energy evaluation for the RDMFT problem will be implemented later.
91-
(void)x;
92-
return 0.0;
137+
this->rdmft_solver.cal_Energy(1); // 1 means not to calculate forces here
138+
return this->rdmft_solver.Etotal;
93139
}
94140

95141
template <typename TK, typename TR>
@@ -101,8 +147,8 @@ Vector& RDMFT_LCAO<TK, TR>::EucGrad(const Variable& x, Vector* result) const
101147
throw std::runtime_error("RDMFT_LCAO::EucGrad requires a valid result storage");
102148
}
103149

104-
*result = x;
105-
result->ScalarTimesThis(0.0);
150+
*result = x; // copy structure
151+
result->ScalarTimesThis(0.0); // zero until real gradient assembly implemented
106152
return *result;
107153
}
108154

@@ -127,6 +173,7 @@ Vector& RDMFT_LCAO<TK, TR>::EucHessianEta(const Variable& x, const Vector& etax,
127173

128174

129175

176+
130177
// Explicit template instantiations for the typical scalar combinations used in ABACUS.
131178
template class RDMFT_LCAO<double, double>;
132179
template class RDMFT_LCAO<std::complex<double>, double>;

source/source_esolver/directmin_problems/rdmft_lcao.h

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@
22
#define RDMFT_LCAO_H
33

44
// #include "Problems/CStieBrockett.h"
5+
#include "Manifolds/Stiefel.h"
56
#include "Manifolds/CStiefel.h"
7+
#include "Manifolds/Euclidean.h"
8+
#include "Manifolds/MultiManifolds.h"
9+
610
#include "Problems/Problem.h"
711

812
#include <memory>
913

14+
/*
1015
#include "source_cell/unitcell.h"
1116
#include "source_io/module_parameter/input_parameter.h"
1217
18+
1319
// for k-points in Brillouin zone
1420
#include "source_cell/klist.h"
1521
// for plane-wave basis
@@ -33,6 +39,7 @@
3339
#include "source_psi/psi.h"
3440
// for Hamiltonian
3541
#include "source_hamilt/hamilt.h"
42+
*/ // Since esolver_ks_lcao.h already includes these headers
3643

3744
#include "source_esolver/esolver_ks_lcao.h"
3845

@@ -42,20 +49,40 @@ namespace ModuleESolver
4249
{
4350

4451
template <typename TK, typename TR>
45-
class RDMFT_LCAO : public Problem //, public ESolver_KS_LCAO<TK, TR>
52+
class RDMFT_LCAO : public Problem, public ESolver_KS_LCAO<TK, TR>
4653
{
4754
public:
4855
// RDMFT_LCAO();
4956
RDMFT_LCAO(UnitCell& ucell, const Input_para& inp);
5057
~RDMFT_LCAO();
5158

52-
virtual realdp f(const Variable &x) const;
59+
// Override pure virtuals from ROPTLITE::Problem (must be const signatures)
60+
realdp f(const Variable &x) const override;
61+
Vector &EucGrad(const Variable &x, Vector *result) const override;
62+
Vector &EucHessianEta(const Variable &x, const Vector &etax, Vector *result) const override;
63+
64+
65+
// some variables for the dimension of the problem, they should be set in the constructor
66+
int nbands_total = 0;
67+
int nspin = 1;
68+
int nk_total = 1;
69+
int nbasis_total= 0;
70+
71+
// int nbasis_local = 0;
5372

54-
virtual Vector &EucGrad(const Variable &x, Vector *result) const;
55-
virtual Vector &EucHessianEta(const Variable &x, const Vector &etax, Vector *result) const;
73+
// In this code, we are dealing with joint optimization of both occupation numbers and wavefunctions
74+
// So we use multimanifolds, composed of CStiefel or Stiefel manifolds for the orbitals
75+
// and Euclidean manifolds for the occupation numbers
76+
CStiefel* mani_cst = nullptr;
77+
Stiefel* mani_st = nullptr;
78+
Euclidean* mani_euc = nullptr;
79+
5680

57-
CStiefel* mani = nullptr;
5881

82+
rdmft::RDMFT<TK, TR> rdmft_solver;
83+
84+
ProductManifold* mani = nullptr;
85+
// const TwoCenterBundle* two_center_bundle = nullptr;
5986

6087
};
6188

source/source_lcao/module_rdmft/rdmft.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ double RDMFT<TK, TR>::cal_E_grad_wfc_occ_num()
292292

293293
// cal_type = 2 just support XC-functional without exx
294294
template <typename TK, typename TR>
295-
void RDMFT<TK, TR>::cal_Energy(const int cal_type)
295+
void RDMFT<TK, TR>::cal_Energy(const int cal_type) const
296296
{
297297
double E_Ewald = pelec->f_en.ewald_energy;
298298
double E_entropy = pelec->f_en.demet;

source/source_lcao/module_rdmft/rdmft.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ class RDMFT
7575
psi::Psi<TK> occNum_HamiltWfc;
7676

7777
//! E_RDMFT[4] stores ETV, Ehartree, Exc, Etotal
78-
double E_RDMFT[4] = {0.0};
79-
double Etotal = 0.0;
78+
mutable double E_RDMFT[4] = {0.0};
79+
mutable double Etotal = 0.0;
8080
// std::vector<double> E_RDMFT(4);
8181

8282
//! initialization of rdmft calculation
@@ -103,7 +103,7 @@ class RDMFT
103103
//! obtain the gradient of total energy with respect to occupation number and wfc
104104
double cal_E_grad_wfc_occ_num();
105105

106-
void cal_Energy(const int cal_type = 1);
106+
void cal_Energy(const int cal_type = 1) const;
107107

108108
//! update occ_number for optimization algorithms that depend on Hamilton
109109
void update_occNumber(const ModuleBase::matrix& occ_number_in);
@@ -184,8 +184,8 @@ class RDMFT
184184
bool exx_spacegroup_symmetry = false;
185185
#endif
186186

187-
double etxc = 0.0;
188-
double vtxc = 0.0;
187+
mutable double etxc = 0.0;
188+
mutable double vtxc = 0.0;
189189
bool only_exx_type = false;
190190
const int cal_E_type = 1; // cal_type = 2 just support XC-functional without exx
191191

0 commit comments

Comments
 (0)