Skip to content

Commit 0680c75

Browse files
author
Kai Luo
committed
Enhance RDMFT module with orbital and occupation optimization classes and interface methods for ABACUS data access
1 parent a48a116 commit 0680c75

File tree

4 files changed

+551
-249
lines changed

4 files changed

+551
-249
lines changed

source/source_esolver/esolver_directmin_lcao.cpp

Lines changed: 91 additions & 232 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ void ESolver_DirectMin_LCAO::others(UnitCell& ucell, const int istep)
185185
//template <typename TK, typename TR>
186186
void ESolver_DirectMin_LCAO::runner(UnitCell& ucell, const int istep)
187187
{
188-
ModuleBase::TITLE("ESolver_DirectMin_LCAO", "runner");
188+
ModuleBase::TITLE(this->classname, "runner");
189189
ModuleBase::timer::tick(this->classname, "runner");
190190

191191
std::cout << "Running DirectMin ESolver with RDMFT" << std::endl;
@@ -399,280 +399,139 @@ void ESolver_DirectMin_LCAO::setup_params(UnitCell& ucell, const Input_para& inp
399399

400400
void ESolver_DirectMin_LCAO::setup_hybrid_optimization(UnitCell& ucell, const Input_para& inp)
401401
{
402-
ModuleBase::TITLE("ESolver_DirectMin_LCAO", "setup_hybrid_optimization");
403-
std::cout << "Setting up hybrid two-level optimization" << std::endl;
404-
405-
// Setup problems for orbital and occupation optimization
406-
this->setup_orbital_problem(ucell, inp);
407-
this->setup_occupation_problem(ucell, inp);
402+
ModuleBase::TITLE(this->classname, "setup_hybrid_optimization");
408403

409-
// Setup solvers for orbital and occupation optimization
410-
this->setup_orbital_solver(ucell, inp);
411-
this->setup_occupation_solver(ucell, inp);
412-
}
413-
414-
void ESolver_DirectMin_LCAO::setup_orbital_problem(UnitCell& ucell, const Input_para& inp)
415-
{
416-
std::cout << "Setting up orbital optimization problem" << std::endl;
404+
// Read RDMFT parameters
405+
loop_layer_ = GlobalV::PARAM_RDMFT.rdmft_loop_layer;
406+
max_iter_occ_ = GlobalV::PARAM_RDMFT.rdmft_max_iter_occ;
407+
max_iter_orb_ = GlobalV::PARAM_RDMFT.rdmft_max_iter_orb;
408+
occ_opt_method_ = GlobalV::PARAM_RDMFT.rdmft_occ_opt_method;
409+
orb_opt_method_ = GlobalV::PARAM_RDMFT.rdmft_orb_opt_method;
417410

418-
// Create initial occupation vector (from DFT or previous iteration)
419-
Vector initial_occupations; // TODO: Initialize from KS calculation
411+
// Setup orbital and occupation problems
412+
setup_orbital_problem(ucell, inp);
413+
setup_occupation_problem(ucell, inp);
420414

421-
// Create RDMFT_Orbital problem with fixed occupations
422-
// this->orbital_prob = new RDMFT_Orbital(ucell, inp, initial_occupations);
415+
// Setup orbital and occupation solvers
416+
setup_orbital_solver(ucell, inp);
417+
setup_occupation_solver(ucell, inp);
423418

424-
// For now, create a placeholder (will be replaced with actual RDMFT_Orbital)
425-
if (inp.directmin_objective == "test") {
426-
this->orbital_prob = new TEST(ucell, inp);
427-
428-
Manifold* space = this->orbital_prob->GetDomain();
429-
if (space != nullptr) {
430-
X_orbitals = space->RandominManifold();
431-
} else {
432-
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "Orbital manifold not defined");
433-
}
434-
X_orbitals.Print("Initial Orbitals");
435-
}
419+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "Hybrid optimization setup complete");
436420
}
437421

438-
void ESolver_DirectMin_LCAO::setup_occupation_problem(UnitCell& ucell, const Input_para& inp)
422+
void ESolver_DirectMin_LCAO::setup_orbital_problem(UnitCell& ucell, const Input_para& inp)
439423
{
440-
std::cout << "Setting up occupation optimization problem" << std::endl;
424+
// Determine dimensions
425+
int nbands = GlobalV::NBANDS;
426+
int nlocal = GlobalV::NLOCAL;
427+
bool gamma_only = GlobalV::GAMMA_ONLY_LOCAL;
441428

442-
// Create initial orbital matrix (from DFT or previous iteration)
443-
Variable initial_orbitals; // TODO: Initialize from KS calculation
429+
// Create orbital problem - pass this ESolver as parameter
430+
orbital_prob = new ModuleRDMFT::RDMFT_Orbital(nlocal, nbands, gamma_only, static_cast<const void*>(this));
444431

445-
// Create RDMFT_Occupation problem with fixed orbitals
446-
// this->occupation_prob = new RDMFT_Occupation(ucell, inp, initial_orbitals);
432+
// Set the domain (Stiefel manifold)
433+
ROPTLITE::Manifold* stiefel_manifold = ModuleRDMFT::create_stiefel_manifold(nlocal, nbands, gamma_only);
434+
orbital_prob->SetDomain(stiefel_manifold);
447435

448-
// For now, create a placeholder (will be replaced with actual RDMFT_Occupation)
449-
if (inp.directmin_objective == "test") {
450-
this->occupation_prob = new TEST(ucell, inp);
451-
452-
Manifold* space = this->occupation_prob->GetDomain();
453-
if (space != nullptr) {
454-
X_occupations = space->RandominManifold();
455-
} else {
456-
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "Occupation manifold not defined");
457-
}
458-
X_occupations.Print("Initial Occupations");
459-
}
460-
}
461-
462-
void ESolver_DirectMin_LCAO::setup_orbital_solver(UnitCell& ucell, const Input_para& inp)
463-
{
464-
std::cout << "Setting up orbital solver" << std::endl;
436+
// Initialize orbital variables
437+
X_orbitals = stiefel_manifold->RandominManifold();
465438

466-
if (this->orbital_solver != nullptr) {
467-
delete this->orbital_solver;
468-
this->orbital_solver = nullptr;
469-
}
470-
471-
if (this->orb_opt_method_ == "cg") {
472-
this->orbital_solver = new RCG(this->orbital_prob, &X_orbitals);
473-
} else if (this->orb_opt_method_ == "sd") {
474-
this->orbital_solver = new RSD(this->orbital_prob, &X_orbitals);
475-
} else if (this->orb_opt_method_ == "bfgs") {
476-
this->orbital_solver = new RBFGS(this->orbital_prob, &X_orbitals);
477-
} else {
478-
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "Unknown orbital optimization method: " + this->orb_opt_method_);
479-
}
480-
481-
// this->orbital_solver->Verbose = FINALRESULT;
482-
std::cout << "Orbital solver: " << this->orb_opt_method_ << std::endl;
439+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "Orbital problem setup complete");
440+
std::cout << "Orbital manifold: " << (gamma_only ? "Stiefel (real)" : "CStiefel (complex)") << std::endl;
441+
std::cout << "Orbital dimensions: " << nlocal << " x " << nbands << std::endl;
483442
}
484443

485-
void ESolver_DirectMin_LCAO::setup_occupation_solver(UnitCell& ucell, const Input_para& inp)
486-
{
487-
std::cout << "Setting up occupation solver" << std::endl;
488-
489-
if (this->occupation_solver != nullptr) {
490-
delete this->occupation_solver;
491-
this->occupation_solver = nullptr;
492-
}
493-
494-
if (this->occ_opt_method_ == "cg") {
495-
this->occupation_solver = new RCG(this->occupation_prob, &X_occupations);
496-
} else if (this->occ_opt_method_ == "sd") {
497-
this->occupation_solver = new RSD(this->occupation_prob, &X_occupations);
498-
} else if (this->occ_opt_method_ == "bfgs") {
499-
this->occupation_solver = new RBFGS(this->occupation_prob, &X_occupations);
500-
} else {
501-
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "Unknown occupation optimization method: " + this->occ_opt_method_);
502-
}
503-
504-
// this->occupation_solver->Verbose = FINALRESULT;
505-
std::cout << "Occupation solver: " << this->occ_opt_method_ << std::endl;
506-
}
507-
508-
void ESolver_DirectMin_LCAO::run_hybrid_optimization(UnitCell& ucell)
444+
void ESolver_DirectMin_LCAO::setup_occupation_problem(UnitCell& ucell, const Input_para& inp)
509445
{
510-
std::cout << "Starting hybrid two-level optimization" << std::endl;
446+
// Number of occupation numbers to optimize
447+
int n_occ = GlobalV::NBANDS;
511448

512-
double energy_old = 1e10;
513-
double energy_new = 0.0;
514-
int iter = 0;
515-
516-
while (iter < this->max_iter && !this->check_convergence()) {
517-
iter++;
518-
std::cout << "\n=== Hybrid Iteration " << iter << " ===" << std::endl;
519-
520-
// Step 1: Optimize orbitals with fixed occupations
521-
std::cout << "Optimizing orbitals..." << std::endl;
522-
this->optimize_orbitals_fixed_occupations();
523-
524-
// Step 2: Optimize occupations with fixed orbitals
525-
std::cout << "Optimizing occupations..." << std::endl;
526-
this->optimize_occupations_fixed_orbitals();
527-
528-
// Check convergence
529-
energy_new = this->cal_energy();
530-
this->print_iteration_info(iter, energy_new, 0.0); // TODO: Add gradient norm
531-
532-
if (std::abs(energy_new - energy_old) < this->tol_f) {
533-
std::cout << "Converged in energy!" << std::endl;
534-
break;
535-
}
536-
energy_old = energy_new;
537-
}
538-
539-
std::cout << "Hybrid optimization completed after " << iter << " iterations" << std::endl;
540-
}
541-
542-
void ESolver_DirectMin_LCAO::optimize_orbitals_fixed_occupations()
543-
{
544-
if (this->orbital_solver == nullptr) {
545-
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "Orbital solver not initialized");
546-
}
449+
// Create occupation problem - pass this ESolver as parameter
450+
occupation_prob = new ModuleRDMFT::RDMFT_Occupation(n_occ, static_cast<const void*>(this));
547451

548-
// Run orbital optimization for specified number of iterations
549-
for (int i = 0; i < this->max_iter_orb_; ++i) {
550-
// TODO: Implement single step of orbital optimization
551-
// For now, just run the solver
552-
break; // Placeholder
553-
}
554-
}
555-
556-
void ESolver_DirectMin_LCAO::optimize_occupations_fixed_orbitals()
557-
{
558-
if (this->occupation_solver == nullptr) {
559-
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "Occupation solver not initialized");
560-
}
452+
// Set the domain (Euclidean space)
453+
ROPTLITE::Manifold* euclidean_manifold = new ROPTLITE::Euclidean(n_occ);
454+
occupation_prob->SetDomain(euclidean_manifold);
561455

562-
// Run occupation optimization for specified number of iterations
563-
for (int i = 0; i < this->max_iter_occ_; ++i) {
564-
// TODO: Implement single step of occupation optimization
565-
// For now, just run the solver
566-
break; // Placeholder
567-
}
568-
}
569-
570-
// ===============================================
571-
// Route 2: Joint Optimization Functions
572-
// ===============================================
573-
574-
void ESolver_DirectMin_LCAO::setup_joint_optimization(UnitCell& ucell, const Input_para& inp)
575-
{
576-
ModuleBase::TITLE("ESolver_DirectMin_LCAO", "setup_joint_optimization");
577-
std::cout << "Setting up joint optimization on product manifold" << std::endl;
456+
// Initialize occupation variables
457+
X_occupations = euclidean_manifold->RandominManifold();
578458

579-
this->setup_multimanifold_problem(ucell, inp);
459+
ModuleBase::GlobalFunc::OUT(GlobalV::ofs_running, "Occupation problem setup complete");
460+
std::cout << "Occupation manifold: Euclidean" << std::endl;
461+
std::cout << "Number of occupation numbers: " << n_occ << std::endl;
580462
}
581463

582-
void ESolver_DirectMin_LCAO::setup_multimanifold_problem(UnitCell& ucell, const Input_para& inp)
464+
// Add missing utility methods
465+
void ESolver_DirectMin_LCAO::print_iteration_info(int iter, double energy, double grad_norm) const
583466
{
584-
std::cout << "Setting up multimanifold problem" << std::endl;
585-
586-
if (inp.directmin_objective == "rdmft") {
587-
// Create full RDMFT problem on product manifold
588-
// this->prob = new RDMFT(ucell, inp);
589-
590-
// TODO: Get initial point from RDMFT problem
591-
// X = this->prob->GetInitialPoint();
592-
593-
std::cout << "RDMFT problem on product manifold created" << std::endl;
594-
} else {
595-
// Fallback to test problem for development
596-
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "RDMFT not implemented yet, use test objective");
467+
std::cout << "RDMFT Iteration " << iter << ": Energy = " << std::scientific << std::setprecision(8) << energy;
468+
if (grad_norm > 0.0) {
469+
std::cout << ", |grad| = " << grad_norm;
597470
}
598-
}
599-
600-
void ESolver_DirectMin_LCAO::run_joint_optimization(UnitCell& ucell)
601-
{
602-
std::cout << "Starting joint optimization" << std::endl;
471+
std::cout << std::endl;
603472

604-
if (this->solver == nullptr) {
605-
ModuleBase::WARNING_QUIT("ESolver_DirectMin_LCAO", "Joint solver not initialized");
473+
GlobalV::ofs_running << "RDMFT Iteration " << iter << ": Energy = " << energy;
474+
if (grad_norm > 0.0) {
475+
GlobalV::ofs_running << ", |grad| = " << grad_norm;
606476
}
607-
608-
// Run the optimization
609-
this->solver->Run();
610-
611-
// Print final results
612-
this->solver->GetXopt().Print("Final Joint Optimization Result");
613-
std::cout << "Joint optimization completed" << std::endl;
477+
GlobalV::ofs_running << std::endl;
614478
}
615479

616-
// ===============================================
617-
// Helper Functions
618-
// ===============================================
619-
620480
bool ESolver_DirectMin_LCAO::check_convergence() const
621481
{
622-
// TODO: Implement proper convergence checking
623-
// For now, return false to let iteration counter control convergence
482+
// Simple convergence check - can be improved
483+
// For now, rely on the main optimization loop convergence
624484
return false;
625485
}
626486

627-
void ESolver_DirectMin_LCAO::print_iteration_info(int iter, double energy, double grad_norm) const
628-
{
629-
std::cout << "Iteration " << iter << ": Energy = " << std::scientific << energy
630-
<< ", Grad norm = " << grad_norm << std::endl;
631-
GlobalV::ofs_running << "Iteration " << iter << ": Energy = " << std::scientific << energy
632-
<< ", Grad norm = " << grad_norm << std::endl;
633-
}
634-
635487
void ESolver_DirectMin_LCAO::cleanup_solvers()
636488
{
637-
// Clean up joint optimization solver
638-
if (this->solver != nullptr) {
639-
delete this->solver;
640-
this->solver = nullptr;
489+
if (solver != nullptr) {
490+
delete solver;
491+
solver = nullptr;
641492
}
642-
643-
// Clean up hybrid optimization solvers
644-
if (this->orbital_solver != nullptr) {
645-
delete this->orbital_solver;
646-
this->orbital_solver = nullptr;
493+
if (orbital_solver != nullptr) {
494+
delete orbital_solver;
495+
orbital_solver = nullptr;
647496
}
648-
649-
if (this->occupation_solver != nullptr) {
650-
delete this->occupation_solver;
651-
this->occupation_solver = nullptr;
497+
if (occupation_solver != nullptr) {
498+
delete occupation_solver;
499+
occupation_solver = nullptr;
652500
}
653501
}
654502

655503
void ESolver_DirectMin_LCAO::cleanup_problems()
656504
{
657-
// Clean up joint optimization problem
658-
if (this->prob != nullptr) {
659-
delete this->prob;
660-
this->prob = nullptr;
505+
if (prob != nullptr) {
506+
delete prob;
507+
prob = nullptr;
661508
}
662-
663-
// Clean up hybrid optimization problems
664-
if (this->orbital_prob != nullptr) {
665-
delete this->orbital_prob;
666-
this->orbital_prob = nullptr;
509+
if (orbital_prob != nullptr) {
510+
delete orbital_prob;
511+
orbital_prob = nullptr;
667512
}
668-
669-
if (this->occupation_prob != nullptr) {
670-
delete this->occupation_prob;
671-
this->occupation_prob = nullptr;
513+
if (occupation_prob != nullptr) {
514+
delete occupation_prob;
515+
occupation_prob = nullptr;
516+
}
517+
if (manifold != nullptr) {
518+
delete manifold;
519+
manifold = nullptr;
672520
}
673521
}
674522

675-
// template class ESolver_DirectMin_LCAO<double, double>;
676-
// template class ESolver_DirectMin_LCAO<std::complex<double>, double>;
677-
// template class ESolver_DirectMin_LCAO<std::complex<double>, std::complex<double>>;
523+
// Add interface methods to access ABACUS data structures
524+
const psi::Psi<double>* ESolver_DirectMin_LCAO::get_psi() const
525+
{
526+
return this->psi;
527+
}
528+
529+
const elecstate::ElecState* ESolver_DirectMin_LCAO::get_pelec() const
530+
{
531+
return this->pelec;
532+
}
533+
534+
const hamilt::Hamilt<double>* ESolver_DirectMin_LCAO::get_hamilt() const
535+
{
536+
return this->p_hamilt;
678537
}

0 commit comments

Comments
 (0)