Skip to content

Commit d8af222

Browse files
author
Kai Luo
committed
Add inline getters for electronic state and wavefunction in ESolver_FP and ESolver_KS classes
1 parent bf57ea6 commit d8af222

File tree

3 files changed

+119
-67
lines changed

3 files changed

+119
-67
lines changed

source/source_esolver/esolver_fp.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ class ESolver_FP: public ESolver
5555

5656
virtual void after_all_runners(UnitCell& ucell) override;
5757

58+
59+
inline elecstate::ElecState* get_pelec() { return pelec; }
60+
inline K_Vectors& get_kv() { return kv; }
61+
5862
protected:
5963
//! Something to do before SCF iterations.
6064
virtual void before_scf(UnitCell& ucell, const int istep);

source/source_esolver/esolver_ks.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class ESolver_KS : public ESolver_FP
3636

3737
virtual void after_all_runners(UnitCell& ucell) override;
3838

39+
inline psi::Psi<T>* get_psi() { return psi; };
40+
3941
protected:
4042
//! Something to do before SCF iterations.
4143
virtual void before_scf(UnitCell& ucell, const int istep) override;

source/source_esolver/esolver_rdmft_lcao.cpp

Lines changed: 113 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -408,19 +408,31 @@ void ESolver_RDMFT_LCAO<TK, TR>::setup_params(UnitCell& ucell, const Input_para&
408408
if (inp.directmin_solver == "cg")
409409
{
410410
static const std::map<std::string, RCGmethods> cg_methods = {
411-
{"fr", FLETCHER_REEVES},
412-
{"pr", POLAK_RIBIERE_MOD},
411+
{"fr", FLETCHER_REEVES},
412+
{"pr", POLAK_RIBIERE_MOD},
413413
{"hs", HESTENES_STIEFEL},
414-
{"frpr", FR_PR},
415-
{"dy", DAI_YUAN},
414+
{"frpr", FR_PR},
415+
{"dy", DAI_YUAN},
416416
{"hz", HAGER_ZHANG}};
417-
auto it = cg_methods.find(inp.directmin_cg_method);
418-
cg_index = (it != cg_methods.end()) ? it->second : -1;
419-
420-
// params.insert(std::pair<std::string, RCGmethods>
421-
// (std::string("RCGmethod"), static_cast<RCGmethods>(cg_index)));
422417

423-
// need to set the CG method for the solver
418+
auto it = cg_methods.find(inp.directmin_cg_method);
419+
cg_index = (it != cg_methods.end()) ? static_cast<int>(it->second) : -1;
420+
421+
std::cout << "CG index: " << cg_index << std::endl;
422+
if (cg_index != -1) {
423+
RCG* cg_solver = dynamic_cast<RCG*>(this->solver);
424+
if (cg_solver == nullptr) {
425+
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO",
426+
"CG solver requested, but current solver is not RCG.");
427+
} else {
428+
std::map<std::string, double> params;
429+
params["RCGmethod"] = static_cast<double>(cg_index);
430+
cg_solver->SetParams(params);
431+
}
432+
} else {
433+
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO",
434+
"Unknown CG method: " + inp.directmin_cg_method);
435+
}
424436
}
425437

426438

@@ -437,17 +449,17 @@ void ESolver_RDMFT_LCAO<TK, TR>::setup_params(UnitCell& ucell, const Input_para&
437449
}
438450

439451

440-
// Set parameters for different solvers based on optimization route
441-
if (this->rdmft_loop_layer_ == 1 && this->solver != nullptr) {
442-
// this->solver->SetParams(params);
443-
} else if (this->rdmft_loop_layer_ == 2) {
444-
if (this->orbital_solver != nullptr) {
445-
// this->orbital_solver->SetParams(params);
446-
}
447-
if (this->occupation_solver != nullptr) {
448-
// this->occupation_solver->SetParams(params);
449-
}
450-
}
452+
// // Set parameters for different solvers based on optimization route
453+
// if (this->rdmft_loop_layer_ == 1 && this->solver != nullptr) {
454+
// // this->solver->SetParams(params);
455+
// } else if (this->rdmft_loop_layer_ == 2) {
456+
// if (this->orbital_solver != nullptr) {
457+
// // this->orbital_solver->SetParams(params);
458+
// }
459+
// if (this->occupation_solver != nullptr) {
460+
// // this->occupation_solver->SetParams(params);
461+
// }
462+
// }
451463
}
452464

453465
// ===============================================
@@ -709,70 +721,104 @@ void ESolver_RDMFT_LCAO<TK, TR>::setup_initial_guess(UnitCell& ucell, const Inpu
709721
ModuleBase::TITLE(this->classname, "setup_initial_guess");
710722

711723
X = this->prob->GetDomain() -> RandominManifold();
724+
// this->X.Print("Initial variable for DirectMin:");
712725

713726
ModuleESolver::ESolver_KS_LCAO<TK, TR> * p_esolver = new ESolver_KS_LCAO<TK, TR>();
714727
p_esolver->before_all_runners(ucell, inp);
715728
p_esolver->runner(ucell, 0);
716729

717730
int nk = p_esolver->get_nk();
718-
int ncol = p_esolver->get_nrow(); // note the ncol in KS solver is nrow in RDMFT problem, nbasis
719-
int nrow = p_esolver->get_ncol(); // note the nrow in KS solver is ncol in RDMFT problem, nbands
731+
int nbasis = p_esolver->get_ncol(); // note the ncol in KS solver, nbasis
732+
int nbands= inp.nbands; // note the nrow in KS solver, nbands
733+
734+
// /*
735+
std::cout << "X.GetElement(0).Getrow(): " << X.GetElement(0).Getrow() << std::endl;
736+
std::cout << "X.GetElement(0).Getcol(): " << X.GetElement(0).Getcol() << std::endl;
737+
std::cout << "X.Getnumofelements(): " << X.Getnumofelements() << std::endl;
738+
739+
std::cout << "nk from KS solver: " << nk << std::endl;
740+
std::cout << "nbasis from KS solver: " << nbasis << std::endl;
741+
std::cout << "nbands from KS solver: " << nbands << std::endl;
742+
// */
720743

721744
if( nk != X.Getnumofelements() / 2)
722745
{
723746
std::cout << "nk from KS solver: " << nk << ", nk from RDMFT problem: " << X.Getnumofelements() / 2<< std::endl;
724747
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "The nk points from KS solver and RDMFT problem do not match.");
725748
}
726-
if ( ncol != X.GetElement(0).Getcol())
749+
if ( nbasis != X.GetElement(0).Getrow())
727750
{
728-
std::cout << "ncol from KS solver: " << ncol << ", ncol from RDMFT problem: " << X.GetElement(0).Getcol()<< std::endl;
751+
std::cout << "nbasis from KS solver: " << nbasis << ", nbasis from RDMFT problem: " << X.GetElement(0).Getrow()<< std::endl;
752+
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "The number of nbasis from KS solver and RDMFT problem do not match.");
753+
}
754+
if ( nbands != X.GetElement(0).Getcol())
755+
{
756+
std::cout << "nbands from KS solver: " << nbands << ", nbands from RDMFT problem: " << X.GetElement(0).Getcol()<< std::endl;
729757
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "The number of bands from KS solver and RDMFT problem do not match.");
730758
}
731-
if ( nrow != X.GetElement(0).Getrow())
732-
{
733-
std::cout << "nrow from KS solver: " << nrow << ", nrow from RDMFT problem: " << X.GetElement(0).Getrow()<< std::endl;
734-
ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "The number of rows from KS solver and RDMFT problem do not match.");
735-
}
736-
737-
// seeded = seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<double, double>*>(p_esolver),
738-
// dynamic_cast<RDMFT_LCAO<double, double>*>(this->prob),
739-
// initial_variable)
740-
// || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, double>*>(p_esolver),
741-
// dynamic_cast<RDMFT_LCAO<std::complex<double>, double>*>(this->prob),
742-
// initial_variable)
743-
// || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, std::complex<double>>*>(p_esolver),
744-
// dynamic_cast<RDMFT_LCAO<std::complex<double>, std::complex<double>>*>(this->prob),
745-
// initial_variable);
746-
// }
747-
// catch (const std::exception& ex)
748-
// {
749-
// seed_error = ex.what();
750-
// }
751-
752-
// ModuleESolver::clean_esolver(p_esolver, false);
753-
754-
// if (!seeded)
755-
// {
756-
// if (!seed_error.empty())
757-
// {
758-
// const std::string message = "Unable to seed DirectMin variable from KS solver: " + seed_error;
759-
// ModuleBase::WARNING("ESolver_RDMFT_LCAO", message.c_str());
760-
// }
761-
// else
762-
// {
763-
// ModuleBase::WARNING("ESolver_RDMFT_LCAO", "Unable to seed DirectMin variable from KS solver; using random initialization");
764-
// }
765-
// }
766-
// }
767-
// else
768-
// {
769-
// ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "Only 'ks' is supported for directmin_init_method currently");
770-
// }
759+
760+
const int occ_offset = nk;
761+
762+
// set up occupation part
763+
for (int ik = 0; ik < nk; ++ik)
764+
{
765+
Element& occ_elem = X.GetElement(occ_offset + ik);
766+
if (occ_elem.Getrow() != nbands || occ_elem.Getcol() != 1 || occ_elem.Getiscomplex())
767+
{
768+
throw std::runtime_error("KSStateToVariable encountered an occupation component with unexpected shape");
769+
}
770+
771+
double* occ_buffer = occ_elem.ObtainWriteEntireData();
772+
for (int ib = 0; ib < nbands; ++ib)
773+
{
774+
K_Vectors & kv = p_esolver->get_kv();
775+
occ_buffer[ib] = p_esolver->get_pelec()->wg(ik, ib) / kv.wk[ik];
776+
std::cout << "Initial occupation for k-point " << ik << ", band " << ib << ": " << occ_buffer[ib] << std::endl;
777+
}
778+
779+
}
780+
781+
// set up wavefunction part
782+
for (int ik = 0; ik < nk; ++ik)
783+
{
784+
Element& wf_elem = X.GetElement(ik);
785+
786+
double* raw_data = wf_elem.ObtainWriteEntireData();
787+
const bool is_complex = wf_elem.Getiscomplex();
788+
const int element_length = wf_elem.Getlength();
789+
std::fill(raw_data, raw_data + element_length, 0.0);
790+
std::complex<double>* complex_buffer = reinterpret_cast<std::complex<double>*>(raw_data);
791+
792+
psi::Psi<TK> * wfc = p_esolver->get_psi();
793+
wfc->fix_k(ik);
794+
795+
for (int ib_local = 0; ib_local < nbands; ++ib_local)
796+
{
797+
wfc->fix_b(ib_local);
798+
799+
for (int ir_local = 0; ir_local < nbasis; ++ir_local)
800+
{
801+
802+
const int linear_index = ir_local + nbasis * ib_local;
803+
const TK value = (*wfc)(ir_local);
804+
805+
if (is_complex)
806+
{
807+
complex_buffer[linear_index] = std::complex<double>(std::real(value), std::imag(value));
808+
}
809+
else
810+
{
811+
raw_data[linear_index] = static_cast<double>(std::real(value));
812+
}
813+
}
814+
}
815+
}
816+
771817

772818
// this->X = std::move(initial_variable);
773-
this->X.Print("Initial variable for DirectMin:");
819+
if(true) // for debug
820+
this->X.Print("Initial variable for DirectMin:");
774821
// }
775-
// Implementation of initial guess setup
776822
}
777823

778824

0 commit comments

Comments
 (0)