Skip to content

Commit 00fa1b2

Browse files
author
Kai Luo
committed
Refactor RDMFT_LCAO: Add UnitCell reference and implement VariableToOccNumWfc method for improved wavefunction handling
1 parent e1756e2 commit 00fa1b2

File tree

3 files changed

+144
-14
lines changed

3 files changed

+144
-14
lines changed

source/source_esolver/directmin_problems/rdmft_lcao.cpp

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "source_base/timer.h"
66

77
#include <complex>
8+
#include <algorithm>
9+
#include <type_traits>
810
#include <stdexcept>
911

1012
#include "source_base/tool_quit.h"
@@ -15,6 +17,21 @@
1517

1618
using namespace ROPTLITE;
1719

20+
namespace
21+
{
22+
template <typename ValueType>
23+
inline ValueType convert_complex_entry(const std::complex<double>& entry, std::true_type)
24+
{
25+
return static_cast<ValueType>(entry.real());
26+
}
27+
28+
template <typename ValueType>
29+
inline ValueType convert_complex_entry(const std::complex<double>& entry, std::false_type)
30+
{
31+
return static_cast<ValueType>(entry);
32+
}
33+
} // namespace
34+
1835
namespace ModuleESolver
1936
{
2037

@@ -24,6 +41,8 @@ RDMFT_LCAO<TK, TR>::RDMFT_LCAO(UnitCell& ucell, const Input_para& inp)
2441
// this->classname = "RDMFT_LCAO";
2542
// this->basisname = "LCAO";
2643

44+
this->ucell_ref_ = &ucell;
45+
2746
#ifdef __EXX
2847
// 1. currently this initialization must be put in constructor rather than `before_all_runners()`
2948
// because the latter is not reused by ESolver_LCAO_TDDFT,
@@ -138,9 +157,23 @@ RDMFT_LCAO<TK, TR>::~RDMFT_LCAO()
138157
template <typename TK, typename TR>
139158
realdp RDMFT_LCAO<TK, TR>::f(const Variable& x) const
140159
{
160+
// convert Variable x to occupation numbers and wavefunctions
161+
ModuleBase::matrix occ_num;
162+
psi::Psi<TK> wfc;
163+
164+
VariableToOccNumWfc(x, occ_num, wfc);
165+
166+
if (this->ucell_ref_ == nullptr)
167+
{
168+
ModuleBase::WARNING("RDMFT_LCAO", "UnitCell reference is unavailable before calling update_elec");
169+
throw std::runtime_error("RDMFT_LCAO::f requires a valid UnitCell reference");
170+
}
171+
172+
this->rdmft_solver.update_elec(*this->ucell_ref_, occ_num, wfc);
173+
141174
this->rdmft_solver.cal_Energy(1); // 1 means not to calculate forces here
142175

143-
std::cout << "Called f() in RDMFT_LCAO, Etotal = " << this->rdmft_solver.Etotal << std::endl;
176+
// std::cout << "Called f() in RDMFT_LCAO, Etotal = " << this->rdmft_solver.Etotal << std::endl;
144177
return this->rdmft_solver.Etotal;
145178
}
146179

@@ -173,6 +206,105 @@ Vector& RDMFT_LCAO<TK, TR>::EucHessianEta(const Variable& x, const Vector& etax,
173206
return *result;
174207
}
175208

209+
template <typename TK, typename TR>
210+
void RDMFT_LCAO<TK, TR>::VariableToOccNumWfc(const Variable& x,
211+
ModuleBase::matrix& occ_num,
212+
psi::Psi<TK>& wfc) const
213+
{
214+
const int required_components = this->nk_total;
215+
const int expected_elements = required_components * 2;
216+
if (x.Getnumofelements() != expected_elements)
217+
{
218+
throw std::runtime_error("RDMFT_LCAO::VariableToOccNumWfc received a variable with unexpected component count");
219+
}
220+
221+
if (occ_num.nr != this->nk_total || occ_num.nc != this->nbands_total)
222+
{
223+
occ_num.create(this->nk_total, this->nbands_total, false);
224+
}
225+
226+
const int local_basis_count = this->pv.get_row_size();
227+
const int local_band_count = this->pv.ncol_bands;
228+
const int available_band_slots = this->pv.get_col_size();
229+
const int band_limit = std::min(local_band_count, available_band_slots);
230+
231+
if (wfc.get_nk() != this->nk_total || wfc.get_nbands() != local_band_count || wfc.get_nbasis() != local_basis_count)
232+
{
233+
wfc.resize(this->nk_total, local_band_count, local_basis_count);
234+
}
235+
236+
// Occupation numbers occupy the second half of the product element.
237+
const int occ_offset = required_components;
238+
for (int ik = 0; ik < this->nk_total; ++ik)
239+
{
240+
const Element& occ_elem = x.GetElement(occ_offset + ik);
241+
if (occ_elem.Getrow() != this->nbands_total || occ_elem.Getcol() != 1)
242+
{
243+
throw std::runtime_error("RDMFT_LCAO::VariableToOccNumWfc encountered an occupation component with unexpected shape");
244+
}
245+
246+
const double* occ_values = occ_elem.ObtainReadData();
247+
for (int ib = 0; ib < this->nbands_total; ++ib)
248+
{
249+
occ_num(ik, ib) = occ_values[ib];
250+
}
251+
}
252+
253+
// Wavefunction blocks sit in the first half of the product element.
254+
for (int ik = 0; ik < this->nk_total; ++ik)
255+
{
256+
const Element& wf_elem = x.GetElement(ik);
257+
const int global_rows = wf_elem.Getrow();
258+
const int global_cols = wf_elem.Getcol();
259+
260+
if (global_rows != this->nbasis_total || global_cols != this->nbands_total)
261+
{
262+
throw std::runtime_error("RDMFT_LCAO::VariableToOccNumWfc encountered a wavefunction block with unexpected shape");
263+
}
264+
265+
const double* raw_data = wf_elem.ObtainReadData();
266+
const bool is_complex = wf_elem.Getiscomplex();
267+
268+
wfc.fix_k(ik);
269+
270+
for (int ib_local = 0; ib_local < band_limit; ++ib_local)
271+
{
272+
const int band_global = this->pv.local2global_col(ib_local);
273+
if (band_global < 0 || band_global >= global_cols)
274+
{
275+
continue;
276+
}
277+
278+
wfc.fix_b(ib_local);
279+
280+
for (int ir_local = 0; ir_local < local_basis_count; ++ir_local)
281+
{
282+
const int basis_global = this->pv.local2global_row(ir_local);
283+
if (basis_global < 0 || basis_global >= global_rows)
284+
{
285+
continue;
286+
}
287+
288+
const int linear_index = basis_global + global_rows * band_global;
289+
290+
TK value;
291+
if (is_complex)
292+
{
293+
const auto* complex_data = reinterpret_cast<const std::complex<double>*>(raw_data);
294+
const std::complex<double> entry = complex_data[linear_index];
295+
value = convert_complex_entry<TK>(entry, typename std::is_same<TK, double>::type());
296+
}
297+
else
298+
{
299+
value = static_cast<TK>(raw_data[linear_index]);
300+
}
301+
302+
wfc(ir_local) = value;
303+
}
304+
}
305+
}
306+
}
307+
176308

177309

178310

source/source_esolver/directmin_problems/rdmft_lcao.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,18 @@ class RDMFT_LCAO : public Problem, public ESolver_KS_LCAO<TK, TR>
7777
Stiefel* mani_st = nullptr;
7878
Euclidean* mani_euc = nullptr;
7979

80-
81-
82-
rdmft::RDMFT<TK, TR> rdmft_solver;
80+
mutable rdmft::RDMFT<TK, TR> rdmft_solver;
81+
UnitCell* ucell_ref_ = nullptr;
8382

8483
ProductManifold* mani = nullptr;
8584
// const TwoCenterBundle* two_center_bundle = nullptr;
8685

86+
// convert a Variable X to occupation numbers and wavefunctions
87+
void VariableToOccNumWfc(const Variable &x,
88+
ModuleBase::matrix &occ_num,
89+
psi::Psi<TK> &wfc) const;
90+
91+
8792
};
8893

8994
} // namespace ModuleESolver

source/source_esolver/esolver_directmin_lcao.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,22 +98,15 @@ void ESolver_DirectMin_LCAO::before_all_runners(UnitCell& ucell, const Input_par
9898
ModuleBase::TITLE(this->classname, "before_all_runners");
9999
ModuleBase::timer::tick(this->classname, "before_all_runners");
100100

101-
// first setup the problem, check inside the objective type
102-
this->setup_problem(ucell, inp);
103-
104-
105-
106-
107-
108101

109102
// Call parent class method properly
110103
// this->ESolver_KS_LCAO<double, double>::before_all_runners(ucell, inp);
111104

112105
// Initialize the DirectMin solver based on the optimization approach
113-
// this->initialize(ucell, inp);
106+
this->initialize(ucell, inp);
114107

115-
// Setup the optimization problem and solver
116-
// this->setup_problem(ucell, inp);
108+
// first setup the problem, check inside the objective type
109+
this->setup_problem(ucell, inp);
117110
this->setup_solver(ucell, inp);
118111
this->setup_params(ucell, inp);
119112

0 commit comments

Comments
 (0)