@@ -365,86 +365,23 @@ void ESolver_RDMFT_LCAO<TK, TR>::setup_solver(UnitCell& ucell, const Input_para&
365365 // for the moment, use random use random orbitals and fixed weights
366366 else if (inp.directmin_objective == " rdmft" )
367367 {
368- // Manifold* domain = this->prob->GetDomain();
369- // if (domain == nullptr)
370- // {
371- // ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "RDMFT problem domain is undefined");
372- // }
373-
374- Variable initial_variable = this ->prob ->GetDomain ()->RandominManifold ();
375- bool seeded = false ;
376- std::string seed_error;
377-
378368 if (this ->init_method_ == " ks" )
379369 {
380- ModuleESolver::ESolver* p_esolver = nullptr ;
381- if (PARAM.globalv .gamma_only_local )
382- {
383- p_esolver = new ESolver_KS_LCAO<double , double >();
384- }
385- else if (PARAM.inp .nspin < 4 )
386- {
387- p_esolver = new ESolver_KS_LCAO<std::complex <double >, double >();
388- }
389- else
390- {
391- p_esolver = new ESolver_KS_LCAO<std::complex <double >, std::complex <double >>();
392- }
393-
394- try
395- {
396- // temporarily use esolver_type=ksdft
397- p_esolver->before_all_runners (ucell, inp);
398-
399- std::cout << " Setting up KS solver for seeding DirectMin variable" << std::endl;
400- p_esolver->runner (ucell, 0 );
401- // check if the runner was successful
402- if (!p_esolver->conv_esolver )
403- {
404- throw std::runtime_error (" KS solver did not converge" );
405- }
406- else
407- {
408- std::cout << " KS solver converged successfully." << std::endl;
409- }
410-
411- // seeded = seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<double, double>*>(p_esolver),
412- // dynamic_cast<RDMFT_LCAO<double, double>*>(this->prob),
413- // initial_variable)
414- // || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, double>*>(p_esolver),
415- // dynamic_cast<RDMFT_LCAO<std::complex<double>, double>*>(this->prob),
416- // initial_variable)
417- // || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, std::complex<double>>*>(p_esolver),
418- // dynamic_cast<RDMFT_LCAO<std::complex<double>, std::complex<double>>*>(this->prob),
419- // initial_variable);
420- }
421- catch (const std::exception& ex)
422- {
423- seed_error = ex.what ();
424- }
425-
426- ModuleESolver::clean_esolver (p_esolver, false );
427-
428- if (!seeded)
429- {
430- if (!seed_error.empty ())
431- {
432- const std::string message = " Unable to seed DirectMin variable from KS solver: " + seed_error;
433- ModuleBase::WARNING (" ESolver_RDMFT_LCAO" , message.c_str ());
434- }
435- else
436- {
437- ModuleBase::WARNING (" ESolver_RDMFT_LCAO" , " Unable to seed DirectMin variable from KS solver; using random initialization" );
438- }
370+ this ->setup_initial_guess (ucell, inp);
371+ }
372+ else if (this ->init_method_ == " random" )
373+ {
374+ Manifold* space = this ->prob ->GetDomain ();
375+ if (!space) {
376+ ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" , " Manifold is not defined in RDMFT_LCAO class" );
377+ return ;
439378 }
379+ this ->X = space->RandominManifold ();
440380 }
441381 else
442382 {
443- ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" , " Only 'ks' is supported for directmin_init_method currently " );
383+ ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" , " Unknown init_method: " + this -> init_method_ );
444384 }
445-
446- this ->X = std::move (initial_variable);
447- this ->X .Print (" Initial variable for DirectMin:" );
448385 }
449386
450387 // next set up the solver based on directmin_solver, sd, cg, bfgs, etc.
@@ -766,6 +703,79 @@ void ESolver_RDMFT_LCAO<TK, TR>::cleanup_problems()
766703 }
767704}
768705
706+ template <typename TK, typename TR>
707+ void ESolver_RDMFT_LCAO<TK, TR>::setup_initial_guess(UnitCell& ucell, const Input_para& inp)
708+ {
709+ ModuleBase::TITLE (this ->classname , " setup_initial_guess" );
710+
711+ X = this ->prob ->GetDomain () -> RandominManifold ();
712+
713+ ModuleESolver::ESolver_KS_LCAO<TK, TR> * p_esolver = new ESolver_KS_LCAO<TK, TR>();
714+ p_esolver->before_all_runners (ucell, inp);
715+ p_esolver->runner (ucell, 0 );
716+
717+ int nk = p_esolver->get_nk ();
718+ int ncol = p_esolver->get_nrow (); // note the ncol in KS solver is nrow in RDMFT problem, nbasis
719+ int nrow = p_esolver->get_ncol (); // note the nrow in KS solver is ncol in RDMFT problem, nbands
720+
721+ if ( nk != X.Getnumofelements () / 2 )
722+ {
723+ std::cout << " nk from KS solver: " << nk << " , nk from RDMFT problem: " << X.Getnumofelements () / 2 << std::endl;
724+ ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" , " The nk points from KS solver and RDMFT problem do not match." );
725+ }
726+ if ( ncol != X.GetElement (0 ).Getcol ())
727+ {
728+ std::cout << " ncol from KS solver: " << ncol << " , ncol from RDMFT problem: " << X.GetElement (0 ).Getcol ()<< std::endl;
729+ ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" , " The number of bands from KS solver and RDMFT problem do not match." );
730+ }
731+ if ( nrow != X.GetElement (0 ).Getrow ())
732+ {
733+ std::cout << " nrow from KS solver: " << nrow << " , nrow from RDMFT problem: " << X.GetElement (0 ).Getrow ()<< std::endl;
734+ ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" , " The number of rows from KS solver and RDMFT problem do not match." );
735+ }
736+
737+ // seeded = seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<double, double>*>(p_esolver),
738+ // dynamic_cast<RDMFT_LCAO<double, double>*>(this->prob),
739+ // initial_variable)
740+ // || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, double>*>(p_esolver),
741+ // dynamic_cast<RDMFT_LCAO<std::complex<double>, double>*>(this->prob),
742+ // initial_variable)
743+ // || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, std::complex<double>>*>(p_esolver),
744+ // dynamic_cast<RDMFT_LCAO<std::complex<double>, std::complex<double>>*>(this->prob),
745+ // initial_variable);
746+ // }
747+ // catch (const std::exception& ex)
748+ // {
749+ // seed_error = ex.what();
750+ // }
751+
752+ // ModuleESolver::clean_esolver(p_esolver, false);
753+
754+ // if (!seeded)
755+ // {
756+ // if (!seed_error.empty())
757+ // {
758+ // const std::string message = "Unable to seed DirectMin variable from KS solver: " + seed_error;
759+ // ModuleBase::WARNING("ESolver_RDMFT_LCAO", message.c_str());
760+ // }
761+ // else
762+ // {
763+ // ModuleBase::WARNING("ESolver_RDMFT_LCAO", "Unable to seed DirectMin variable from KS solver; using random initialization");
764+ // }
765+ // }
766+ // }
767+ // else
768+ // {
769+ // ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "Only 'ks' is supported for directmin_init_method currently");
770+ // }
771+
772+ // this->X = std::move(initial_variable);
773+ this ->X .Print (" Initial variable for DirectMin:" );
774+ // }
775+ // Implementation of initial guess setup
776+ }
777+
778+
769779// Add interface methods to access ABACUS data structures
770780// const psi::Psi<double>* ESolver_RDMFT_LCAO<TK, TR>::get_psi() const
771781// {
0 commit comments