Skip to content

Commit b28d5bf

Browse files
author
Kai Luo
committed
Refactor ESolver_DirectMin_LCAO: Enhance parameter checking and clean up commented-out optimization methods
1 parent e862fd8 commit b28d5bf

File tree

3 files changed

+57
-48
lines changed

3 files changed

+57
-48
lines changed

source/source_esolver/directmin_problems/rdmft_lcao.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,16 @@ RDMFT_LCAO<TK, TR>::RDMFT_LCAO(UnitCell& ucell, const Input_para& inp)
8585
{
8686
this->mani_st = new Stiefel(stiefel_n, stiefel_p);
8787
this->mani_st->ChooseParamsSet1(); // change later
88+
this->mani_st->CheckParams();
89+
8890
this->mani = new ProductManifold(numoftypes, this->mani_st, multitude, this->mani_euc, multitude);
8991
}
9092
else
9193
{
9294
this->mani_cst = new CStiefel(stiefel_n, stiefel_p);
9395
this->mani_cst->ChooseParamsSet1(); // change later
96+
this->mani_cst->CheckParams();
97+
9498
this->mani = new ProductManifold(numoftypes, this->mani_cst, multitude, this->mani_euc, multitude);
9599
}
96100

@@ -135,6 +139,8 @@ template <typename TK, typename TR>
135139
realdp RDMFT_LCAO<TK, TR>::f(const Variable& x) const
136140
{
137141
this->rdmft_solver.cal_Energy(1); // 1 means not to calculate forces here
142+
143+
std::cout << "Called f() in RDMFT_LCAO, Etotal = " << this->rdmft_solver.Etotal << std::endl;
138144
return this->rdmft_solver.Etotal;
139145
}
140146

source/source_esolver/esolver_directmin_lcao.cpp

Lines changed: 48 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,8 @@ void ESolver_DirectMin_LCAO::runner(UnitCell& ucell, const int istep)
211211
if (this->rdmft_loop_layer_ == 1) {
212212
// Route 2: Joint optimization on product manifold
213213
std::cout << "Using joint optimization (Route 1)" << std::endl;
214-
this->run_joint_optimization(ucell);
214+
this->solver->Run();
215+
215216
} else if (this->rdmft_loop_layer_ == 2) {
216217
// Route 1: Hybrid two-level optimization
217218
std::cout << "Using hybrid two-level optimization (Route 2)" << std::endl;
@@ -331,15 +332,15 @@ void ESolver_DirectMin_LCAO::setup_problem(UnitCell& ucell, const Input_para& in
331332
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "Only LCAO basis is supported for RDMFT in DirectMin");
332333
}
333334

334-
// check optimization route
335-
if(inp.rdmftp.rdmft_loop_layer == 1)
336-
{
337-
this->setup_joint_optimization(ucell, inp);
338-
}
339-
else if(inp.rdmftp.rdmft_loop_layer == 2)
340-
{
341-
this->setup_hybrid_optimization(ucell, inp);
342-
}
335+
// // check optimization route
336+
// if(inp.rdmftp.rdmft_loop_layer == 1)
337+
// {
338+
// this->setup_joint_optimization(ucell, inp);
339+
// }
340+
// else if(inp.rdmftp.rdmft_loop_layer == 2)
341+
// {
342+
// this->setup_hybrid_optimization(ucell, inp);
343+
// }
343344
}
344345
else
345346
{
@@ -381,6 +382,7 @@ void ESolver_DirectMin_LCAO::setup_solver(UnitCell& ucell, const Input_para& inp
381382
else if (inp.directmin_objective == "rdmft")
382383
{
383384
X = this->prob->GetDomain()->RandominManifold();
385+
384386
}
385387

386388
// next set up the solver based on directmin_solver, sd, cg, bfgs, etc.
@@ -629,46 +631,47 @@ void ESolver_DirectMin_LCAO::optimize_occupations_fixed_orbitals()
629631
std::cout << "[DirectMin] Running occupation solver (stub)." << std::endl;
630632
}
631633

632-
void ESolver_DirectMin_LCAO::setup_joint_optimization(UnitCell& ucell, const Input_para& inp)
633-
{
634-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "setup_joint_optimization", "Stub implementation");
635-
636-
if (inp.directmin_objective == "test")
637-
{
638-
this->prob = new TEST(ucell, inp);
639-
}
640-
else
641-
{
642-
std::cout << "[DirectMin] RDMFT joint optimization problem setup is not implemented yet." << std::endl;
643-
this->prob = nullptr;
644-
}
645-
646-
if (this->prob != nullptr)
647-
{
648-
Manifold* space = this->prob->GetDomain();
649-
if (space != nullptr)
650-
{
651-
this->X = space->RandominManifold();
652-
}
653-
}
654-
}
634+
// void ESolver_DirectMin_LCAO::setup_joint_optimization(UnitCell& ucell, const Input_para& inp)
635+
// {
636+
// ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "setup_joint_optimization", "Stub implementation");
637+
638+
// if (inp.directmin_objective == "test")
639+
// {
640+
// this->prob = new TEST(ucell, inp);
641+
// }
642+
// else
643+
// {
644+
// this->prob = new RDMFT_LCAO
645+
// // std::cout << "[DirectMin] RDMFT joint optimization problem setup is not implemented yet." << std::endl;
646+
// // this->prob = nullptr;
647+
// }
648+
649+
// if (this->prob != nullptr)
650+
// {
651+
// Manifold* space = this->prob->GetDomain();
652+
// if (space != nullptr)
653+
// {
654+
// this->X = space->RandominManifold();
655+
// }
656+
// }
657+
// }
655658

656-
void ESolver_DirectMin_LCAO::run_joint_optimization(UnitCell& /*ucell*/)
657-
{
658-
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "run_joint_optimization", "Stub implementation");
659+
// void ESolver_DirectMin_LCAO::run_joint_optimization(UnitCell& /*ucell*/)
660+
// {
661+
// ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "run_joint_optimization", "Stub implementation");
659662

660-
if (this->solver == nullptr)
661-
{
662-
std::cout << "[DirectMin] Joint optimization solver not initialized." << std::endl;
663-
return;
664-
}
663+
// if (this->solver == nullptr)
664+
// {
665+
// std::cout << "[DirectMin] Joint optimization solver not initialized." << std::endl;
666+
// return;
667+
// }
665668

666-
std::cout << "[DirectMin] Running joint solver." << std::endl;
669+
// std::cout << "[DirectMin] Running joint solver." << std::endl;
667670

668-
this->solver->Run();
671+
// this->solver->Run();
669672

670-
std::cout << "[DirectMin] Joint optimization complete." << std::endl;
671-
}
673+
// std::cout << "[DirectMin] Joint optimization complete." << std::endl;
674+
// }
672675

673676
// Add missing utility methods
674677
void ESolver_DirectMin_LCAO::print_iteration_info(int iter, double energy, double grad_norm) const

source/source_esolver/esolver_directmin_lcao.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ class ESolver_DirectMin_LCAO : public ESolver_DirectMin
9999
void optimize_occupations_fixed_orbitals();
100100

101101
// Route 2: Joint optimization functions
102-
void setup_joint_optimization(UnitCell& ucell, const Input_para& inp);
103-
void setup_multimanifold_problem(UnitCell& ucell, const Input_para& inp);
104-
void run_joint_optimization(UnitCell& ucell);
102+
// void setup_joint_optimization(UnitCell& ucell, const Input_para& inp);
103+
// void setup_multimanifold_problem(UnitCell& ucell, const Input_para& inp);
104+
// void run_joint_optimization(UnitCell& ucell);
105105

106106
// Convergence checking
107107
bool check_convergence() const;

0 commit comments

Comments
 (0)