Skip to content

Commit 4990f99

Browse files
author
Kai Luo
committed
Add KSStateToVariable method for handling occupation numbers and wavefunctions in RDMFT_LCAO
1 parent 94bad6a commit 4990f99

File tree

4 files changed

+226
-5
lines changed

4 files changed

+226
-5
lines changed

source/source_esolver/directmin_problems/rdmft_lcao.cpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,106 @@ void RDMFT_LCAO<TK, TR>::VariableToOccNumWfc(const Variable& x,
305305
}
306306
}
307307

308+
template <typename TK, typename TR>
309+
void RDMFT_LCAO<TK, TR>::KSStateToVariable(const ModuleBase::matrix& occ_num,
310+
const psi::Psi<TK>& wfc,
311+
Variable& x) const
312+
{
313+
const int required_components = this->nk_total;
314+
const int expected_elements = required_components * 2;
315+
if (x.Getnumofelements() != expected_elements)
316+
{
317+
throw std::runtime_error("KSStateToVariable received a variable with unexpected component count");
318+
}
319+
320+
if (occ_num.nr != this->nk_total || occ_num.nc != this->nbands_total)
321+
{
322+
throw std::runtime_error("KSStateToVariable received occupation matrix with inconsistent shape");
323+
}
324+
325+
const int occ_offset = required_components;
326+
for (int ik = 0; ik < this->nk_total; ++ik)
327+
{
328+
Element& occ_elem = x.GetElement(occ_offset + ik);
329+
if (occ_elem.Getrow() != this->nbands_total || occ_elem.Getcol() != 1 || occ_elem.Getiscomplex())
330+
{
331+
throw std::runtime_error("KSStateToVariable encountered an occupation component with unexpected shape");
332+
}
333+
334+
double* occ_buffer = occ_elem.ObtainWriteEntireData();
335+
for (int ib = 0; ib < this->nbands_total; ++ib)
336+
{
337+
occ_buffer[ib] = occ_num(ik, ib);
338+
}
339+
}
340+
341+
const int local_basis_count = this->pv.get_row_size();
342+
const int local_band_slots = this->pv.get_col_size();
343+
const int local_band_count = this->pv.ncol_bands;
344+
const int band_limit = std::min(local_band_count, local_band_slots);
345+
346+
if (wfc.get_nk() != this->nk_total || wfc.get_nbands() != local_band_count || wfc.get_nbasis() != local_basis_count)
347+
{
348+
throw std::runtime_error("KSStateToVariable received wavefunction object with incompatible dimensions");
349+
}
350+
351+
for (int ik = 0; ik < required_components; ++ik)
352+
{
353+
Element& wf_elem = x.GetElement(ik);
354+
if (wf_elem.Getrow() != this->nbasis_total || wf_elem.Getcol() != this->nbands_total)
355+
{
356+
throw std::runtime_error("KSStateToVariable encountered a wavefunction block with unexpected global shape");
357+
}
358+
359+
double* raw_data = wf_elem.ObtainWriteEntireData();
360+
const bool is_complex = wf_elem.Getiscomplex();
361+
const int element_length = wf_elem.Getlength();
362+
std::fill(raw_data, raw_data + element_length, 0.0);
363+
std::complex<double>* complex_buffer = reinterpret_cast<std::complex<double>*>(raw_data);
364+
365+
wfc.fix_k(ik);
366+
367+
for (int ib_local = 0; ib_local < band_limit; ++ib_local)
368+
{
369+
const int band_global = this->pv.local2global_col(ib_local);
370+
if (band_global < 0 || band_global >= this->nbands_total)
371+
{
372+
continue;
373+
}
374+
375+
wfc.fix_b(ib_local);
376+
377+
for (int ir_local = 0; ir_local < local_basis_count; ++ir_local)
378+
{
379+
const int basis_global = this->pv.local2global_row(ir_local);
380+
if (basis_global < 0 || basis_global >= this->nbasis_total)
381+
{
382+
continue;
383+
}
384+
385+
const int linear_index = basis_global + this->nbasis_total * band_global;
386+
const TK value = wfc(ir_local);
387+
388+
if (is_complex)
389+
{
390+
complex_buffer[linear_index] = std::complex<double>(std::real(value), std::imag(value));
391+
}
392+
else
393+
{
394+
raw_data[linear_index] = static_cast<double>(std::real(value));
395+
}
396+
}
397+
}
398+
}
399+
400+
if (this->ucell_ref_ == nullptr)
401+
{
402+
throw std::runtime_error("KSStateToVariable requires a valid UnitCell reference");
403+
}
404+
405+
this->rdmft_solver.update_elec(*this->ucell_ref_, occ_num, wfc);
406+
}
407+
308408

309409

310410

source/source_esolver/directmin_problems/rdmft_lcao.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ class RDMFT_LCAO : public Problem, public ESolver_KS_LCAO<TK, TR>
8888
ModuleBase::matrix &occ_num,
8989
psi::Psi<TK> &wfc) const;
9090

91+
void KSStateToVariable(const ModuleBase::matrix& occ_num,
92+
const psi::Psi<TK>& wfc,
93+
Variable& x) const;
94+
9195

9296
};
9397

source/source_esolver/esolver_directmin_lcao.cpp

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include "Manifolds/Euclidean.h"
1616
#include "Manifolds/CStiefel.h"
1717

18+
#include <string>
19+
1820
// using namespace ROPTLITE;
1921

2022
// #include "source_lcao/module_rdmft/rdmft.h"
@@ -46,6 +48,25 @@
4648
#include "source_lcao/module_ri/exx_opt_orb.h"
4749
#endif
4850

51+
namespace
52+
{
53+
template <typename TK, typename TR>
54+
bool seed_variable_from_ks_solver(ModuleESolver::ESolver_KS_LCAO<TK, TR>* ks_solver,
55+
ModuleESolver::RDMFT_LCAO<TK, TR>* rdmft_prob,
56+
ROPTLITE::Variable& variable)
57+
{
58+
if (ks_solver == nullptr || rdmft_prob == nullptr)
59+
{
60+
return false;
61+
}
62+
63+
rdmft_prob->KSStateToVariable(ks_solver->get_elecstate().wg,
64+
ks_solver->get_wavefunctions(),
65+
variable);
66+
return true;
67+
}
68+
} // namespace
69+
4970
namespace ModuleESolver
5071
{
5172

@@ -361,14 +382,73 @@ void ESolver_DirectMin_LCAO::setup_solver(UnitCell& ucell, const Input_para& inp
361382
// for the moment, use random use random orbitals and fixed weights
362383
else if (inp.directmin_objective == "rdmft")
363384
{
364-
X = this->prob->GetDomain()->RandominManifold();
365-
// check if directmin_init_method is "ks", then we need to get the KS wavefunctions and occupations
366-
// (was rdmft_init_method)
367-
// from the previous SCF calculation
385+
Manifold* domain = this->prob->GetDomain();
386+
if (domain == nullptr)
387+
{
388+
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "RDMFT problem domain is undefined");
389+
}
390+
391+
Variable initial_variable = domain->RandominManifold();
392+
bool seeded = false;
393+
std::string seed_error;
394+
368395
if (this->init_method_ == "ks")
369396
{
370-
397+
ModuleESolver::ESolver* p_esolver = nullptr;
398+
if (PARAM.globalv.gamma_only_local)
399+
{
400+
p_esolver = new ESolver_KS_LCAO<double, double>();
401+
}
402+
else if (PARAM.inp.nspin < 4)
403+
{
404+
p_esolver = new ESolver_KS_LCAO<std::complex<double>, double>();
405+
}
406+
else
407+
{
408+
p_esolver = new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
409+
}
410+
411+
try
412+
{
413+
p_esolver->before_all_runners(ucell, inp);
414+
p_esolver->runner(ucell, 0);
415+
416+
seeded = seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<double, double>*>(p_esolver),
417+
dynamic_cast<RDMFT_LCAO<double, double>*>(this->prob),
418+
initial_variable)
419+
|| seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, double>*>(p_esolver),
420+
dynamic_cast<RDMFT_LCAO<std::complex<double>, double>*>(this->prob),
421+
initial_variable)
422+
|| seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, std::complex<double>>*>(p_esolver),
423+
dynamic_cast<RDMFT_LCAO<std::complex<double>, std::complex<double>>*>(this->prob),
424+
initial_variable);
425+
}
426+
catch (const std::exception& ex)
427+
{
428+
seed_error = ex.what();
429+
}
430+
431+
ModuleESolver::clean_esolver(p_esolver, false);
432+
433+
if (!seeded)
434+
{
435+
if (!seed_error.empty())
436+
{
437+
const std::string message = "Unable to seed DirectMin variable from KS solver: " + seed_error;
438+
ModuleBase::WARNING("ESolver_DirectMin_LCAO", message.c_str());
439+
}
440+
else
441+
{
442+
ModuleBase::WARNING("ESolver_DirectMin_LCAO", "Unable to seed DirectMin variable from KS solver; using random initialization");
443+
}
444+
}
371445
}
446+
else
447+
{
448+
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "Only 'ks' is supported for directmin_init_method currently");
449+
}
450+
451+
X = std::move(initial_variable);
372452
}
373453

374454
// next set up the solver based on directmin_solver, sd, cg, bfgs, etc.

source/source_esolver/esolver_ks_lcao.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define ESOLVER_KS_LCAO_H
33

44
#include "esolver_ks.h"
5+
#include <stdexcept>
56

67
// for adjacent atoms
78
#include "source_lcao/record_adj.h"
@@ -62,6 +63,42 @@ class ESolver_KS_LCAO : public ESolver_KS<TK>
6263

6364
void after_all_runners(UnitCell& ucell) override;
6465

66+
const psi::Psi<TK>& get_wavefunctions() const
67+
{
68+
if (this->psi == nullptr)
69+
{
70+
throw std::runtime_error("ESolver_KS_LCAO wavefunctions are not initialized");
71+
}
72+
return *(this->psi);
73+
}
74+
75+
psi::Psi<TK>& get_wavefunctions()
76+
{
77+
if (this->psi == nullptr)
78+
{
79+
throw std::runtime_error("ESolver_KS_LCAO wavefunctions are not initialized");
80+
}
81+
return *(this->psi);
82+
}
83+
84+
const elecstate::ElecState& get_elecstate() const
85+
{
86+
if (this->pelec == nullptr)
87+
{
88+
throw std::runtime_error("ESolver_KS_LCAO elecstate is not initialized");
89+
}
90+
return *(this->pelec);
91+
}
92+
93+
elecstate::ElecState& get_elecstate()
94+
{
95+
if (this->pelec == nullptr)
96+
{
97+
throw std::runtime_error("ESolver_KS_LCAO elecstate is not initialized");
98+
}
99+
return *(this->pelec);
100+
}
101+
65102
protected:
66103
virtual void before_scf(UnitCell& ucell, const int istep) override;
67104

0 commit comments

Comments
 (0)