Skip to content

Commit bf57ea6

Browse files
author
Kai Luo
committed
Refactor ESolver_RDMFT_LCAO: Remove redundant wavefunction and elecstate getters, and implement setup_initial_guess method for improved initialization
1 parent 4cf918e commit bf57ea6

File tree

3 files changed

+91
-108
lines changed

3 files changed

+91
-108
lines changed

source/source_esolver/esolver_ks_lcao.h

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -63,41 +63,6 @@ class ESolver_KS_LCAO : public ESolver_KS<TK>
6363

6464
void after_all_runners(UnitCell& ucell) override;
6565

66-
const psi::Psi<TK>& get_wavefunctions() const
67-
{
68-
if (this->psi == nullptr)
69-
{
70-
throw std::runtime_error("ESolver_KS_LCAO wavefunctions are not initialized");
71-
}
72-
return *(this->psi);
73-
}
74-
75-
psi::Psi<TK>& get_wavefunctions()
76-
{
77-
if (this->psi == nullptr)
78-
{
79-
throw std::runtime_error("ESolver_KS_LCAO wavefunctions are not initialized");
80-
}
81-
return *(this->psi);
82-
}
83-
84-
const elecstate::ElecState& get_elecstate() const
85-
{
86-
if (this->pelec == nullptr)
87-
{
88-
throw std::runtime_error("ESolver_KS_LCAO elecstate is not initialized");
89-
}
90-
return *(this->pelec);
91-
}
92-
93-
elecstate::ElecState& get_elecstate()
94-
{
95-
if (this->pelec == nullptr)
96-
{
97-
throw std::runtime_error("ESolver_KS_LCAO elecstate is not initialized");
98-
}
99-
return *(this->pelec);
100-
}
10166

10267
protected:
10368
virtual void before_scf(UnitCell& ucell, const int istep) override;
@@ -162,6 +127,12 @@ class ESolver_KS_LCAO : public ESolver_KS<TK>
162127

163128
friend class LR::ESolver_LR<double, double>;
164129
friend class LR::ESolver_LR<std::complex<double>, double>;
130+
131+
public:
132+
const Parallel_Orbitals& get_pv() const { return pv; }
133+
int get_nk() const { return this->kv.get_nks(); } // if pv.nk exists and is public
134+
int get_ncol() const { return pv.ncol; }
135+
int get_nrow() const { return pv.nrow; }
165136
};
166137
} // namespace ModuleESolver
167138
#endif

source/source_esolver/esolver_rdmft_lcao.cpp

Lines changed: 83 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -365,86 +365,23 @@ void ESolver_RDMFT_LCAO<TK, TR>::setup_solver(UnitCell& ucell, const Input_para&
365365
// for the moment, use random use random orbitals and fixed weights
366366
else if (inp.directmin_objective == "rdmft")
367367
{
368-
// Manifold* domain = this->prob->GetDomain();
369-
// if (domain == nullptr)
370-
// {
371-
// ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "RDMFT problem domain is undefined");
372-
// }
373-
374-
Variable initial_variable = this->prob->GetDomain()->RandominManifold();
375-
bool seeded = false;
376-
std::string seed_error;
377-
378368
if (this->init_method_ == "ks")
379369
{
380-
ModuleESolver::ESolver* p_esolver = nullptr;
381-
if (PARAM.globalv.gamma_only_local)
382-
{
383-
p_esolver = new ESolver_KS_LCAO<double, double>();
384-
}
385-
else if (PARAM.inp.nspin < 4)
386-
{
387-
p_esolver = new ESolver_KS_LCAO<std::complex<double>, double>();
388-
}
389-
else
390-
{
391-
p_esolver = new ESolver_KS_LCAO<std::complex<double>, std::complex<double>>();
392-
}
393-
394-
try
395-
{
396-
// temporarily use esolver_type=ksdft
397-
p_esolver->before_all_runners(ucell, inp);
398-
399-
std::cout << "Setting up KS solver for seeding DirectMin variable" << std::endl;
400-
p_esolver->runner(ucell, 0);
401-
// check if the runner was successful
402-
if (!p_esolver->conv_esolver )
403-
{
404-
throw std::runtime_error("KS solver did not converge");
405-
}
406-
else
407-
{
408-
std::cout << "KS solver converged successfully." << std::endl;
409-
}
410-
411-
// seeded = seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<double, double>*>(p_esolver),
412-
// dynamic_cast<RDMFT_LCAO<double, double>*>(this->prob),
413-
// initial_variable)
414-
// || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, double>*>(p_esolver),
415-
// dynamic_cast<RDMFT_LCAO<std::complex<double>, double>*>(this->prob),
416-
// initial_variable)
417-
// || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, std::complex<double>>*>(p_esolver),
418-
// dynamic_cast<RDMFT_LCAO<std::complex<double>, std::complex<double>>*>(this->prob),
419-
// initial_variable);
420-
}
421-
catch (const std::exception& ex)
422-
{
423-
seed_error = ex.what();
424-
}
425-
426-
ModuleESolver::clean_esolver(p_esolver, false);
427-
428-
if (!seeded)
429-
{
430-
if (!seed_error.empty())
431-
{
432-
const std::string message = "Unable to seed DirectMin variable from KS solver: " + seed_error;
433-
ModuleBase::WARNING("ESolver_RDMFT_LCAO", message.c_str());
434-
}
435-
else
436-
{
437-
ModuleBase::WARNING("ESolver_RDMFT_LCAO", "Unable to seed DirectMin variable from KS solver; using random initialization");
438-
}
370+
this->setup_initial_guess(ucell, inp);
371+
}
372+
else if (this->init_method_ == "random")
373+
{
374+
Manifold* space = this->prob->GetDomain();
375+
if (!space) {
376+
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "Manifold is not defined in RDMFT_LCAO class");
377+
return;
439378
}
379+
this->X = space->RandominManifold();
440380
}
441381
else
442382
{
443-
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "Only 'ks' is supported for directmin_init_method currently");
383+
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "Unknown init_method: " + this->init_method_);
444384
}
445-
446-
this->X = std::move(initial_variable);
447-
this->X.Print("Initial variable for DirectMin:");
448385
}
449386

450387
// next set up the solver based on directmin_solver, sd, cg, bfgs, etc.
@@ -766,6 +703,79 @@ void ESolver_RDMFT_LCAO<TK, TR>::cleanup_problems()
766703
}
767704
}
768705

706+
template <typename TK, typename TR>
707+
void ESolver_RDMFT_LCAO<TK, TR>::setup_initial_guess(UnitCell& ucell, const Input_para& inp)
708+
{
709+
ModuleBase::TITLE(this->classname, "setup_initial_guess");
710+
711+
X = this->prob->GetDomain() -> RandominManifold();
712+
713+
ModuleESolver::ESolver_KS_LCAO<TK, TR> * p_esolver = new ESolver_KS_LCAO<TK, TR>();
714+
p_esolver->before_all_runners(ucell, inp);
715+
p_esolver->runner(ucell, 0);
716+
717+
int nk = p_esolver->get_nk();
718+
int ncol = p_esolver->get_nrow(); // note the ncol in KS solver is nrow in RDMFT problem, nbasis
719+
int nrow = p_esolver->get_ncol(); // note the nrow in KS solver is ncol in RDMFT problem, nbands
720+
721+
if( nk != X.Getnumofelements() / 2)
722+
{
723+
std::cout << "nk from KS solver: " << nk << ", nk from RDMFT problem: " << X.Getnumofelements() / 2<< std::endl;
724+
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "The nk points from KS solver and RDMFT problem do not match.");
725+
}
726+
if ( ncol != X.GetElement(0).Getcol())
727+
{
728+
std::cout << "ncol from KS solver: " << ncol << ", ncol from RDMFT problem: " << X.GetElement(0).Getcol()<< std::endl;
729+
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "The number of bands from KS solver and RDMFT problem do not match.");
730+
}
731+
if ( nrow != X.GetElement(0).Getrow())
732+
{
733+
std::cout << "nrow from KS solver: " << nrow << ", nrow from RDMFT problem: " << X.GetElement(0).Getrow()<< std::endl;
734+
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "The number of rows from KS solver and RDMFT problem do not match.");
735+
}
736+
737+
// seeded = seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<double, double>*>(p_esolver),
738+
// dynamic_cast<RDMFT_LCAO<double, double>*>(this->prob),
739+
// initial_variable)
740+
// || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, double>*>(p_esolver),
741+
// dynamic_cast<RDMFT_LCAO<std::complex<double>, double>*>(this->prob),
742+
// initial_variable)
743+
// || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, std::complex<double>>*>(p_esolver),
744+
// dynamic_cast<RDMFT_LCAO<std::complex<double>, std::complex<double>>*>(this->prob),
745+
// initial_variable);
746+
// }
747+
// catch (const std::exception& ex)
748+
// {
749+
// seed_error = ex.what();
750+
// }
751+
752+
// ModuleESolver::clean_esolver(p_esolver, false);
753+
754+
// if (!seeded)
755+
// {
756+
// if (!seed_error.empty())
757+
// {
758+
// const std::string message = "Unable to seed DirectMin variable from KS solver: " + seed_error;
759+
// ModuleBase::WARNING("ESolver_RDMFT_LCAO", message.c_str());
760+
// }
761+
// else
762+
// {
763+
// ModuleBase::WARNING("ESolver_RDMFT_LCAO", "Unable to seed DirectMin variable from KS solver; using random initialization");
764+
// }
765+
// }
766+
// }
767+
// else
768+
// {
769+
// ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "Only 'ks' is supported for directmin_init_method currently");
770+
// }
771+
772+
// this->X = std::move(initial_variable);
773+
this->X.Print("Initial variable for DirectMin:");
774+
// }
775+
// Implementation of initial guess setup
776+
}
777+
778+
769779
// Add interface methods to access ABACUS data structures
770780
// const psi::Psi<double>* ESolver_RDMFT_LCAO<TK, TR>::get_psi() const
771781
// {

source/source_esolver/esolver_rdmft_lcao.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ class ESolver_RDMFT_LCAO : public ESolver_KS_LCAO<TK, TR>
104104
// Helper functions
105105
void cleanup_solvers();
106106
void cleanup_problems();
107+
108+
void setup_initial_guess(UnitCell& ucell, const Input_para& inp);
107109
};
108110

109111

0 commit comments

Comments
 (0)