@@ -408,19 +408,31 @@ void ESolver_RDMFT_LCAO<TK, TR>::setup_params(UnitCell& ucell, const Input_para&
408408 if (inp.directmin_solver == " cg" )
409409 {
410410 static const std::map<std::string, RCGmethods> cg_methods = {
411- {" fr" , FLETCHER_REEVES},
412- {" pr" , POLAK_RIBIERE_MOD},
411+ {" fr" , FLETCHER_REEVES},
412+ {" pr" , POLAK_RIBIERE_MOD},
413413 {" hs" , HESTENES_STIEFEL},
414- {" frpr" , FR_PR},
415- {" dy" , DAI_YUAN},
414+ {" frpr" , FR_PR},
415+ {" dy" , DAI_YUAN},
416416 {" hz" , HAGER_ZHANG}};
417- auto it = cg_methods.find (inp.directmin_cg_method );
418- cg_index = (it != cg_methods.end ()) ? it->second : -1 ;
419-
420- // params.insert(std::pair<std::string, RCGmethods>
421- // (std::string("RCGmethod"), static_cast<RCGmethods>(cg_index)));
422417
423- // need to set the CG method for the solver
418+ auto it = cg_methods.find (inp.directmin_cg_method );
419+ cg_index = (it != cg_methods.end ()) ? static_cast <int >(it->second ) : -1 ;
420+
421+ std::cout << " CG index: " << cg_index << std::endl;
422+ if (cg_index != -1 ) {
423+ RCG* cg_solver = dynamic_cast <RCG*>(this ->solver );
424+ if (cg_solver == nullptr ) {
425+ ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" ,
426+ " CG solver requested, but current solver is not RCG." );
427+ } else {
428+ std::map<std::string, double > params;
429+ params[" RCGmethod" ] = static_cast <double >(cg_index);
430+ cg_solver->SetParams (params);
431+ }
432+ } else {
433+ ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" ,
434+ " Unknown CG method: " + inp.directmin_cg_method );
435+ }
424436 }
425437
426438
@@ -437,17 +449,17 @@ void ESolver_RDMFT_LCAO<TK, TR>::setup_params(UnitCell& ucell, const Input_para&
437449 }
438450
439451
440- // Set parameters for different solvers based on optimization route
441- if (this ->rdmft_loop_layer_ == 1 && this ->solver != nullptr ) {
442- // this->solver->SetParams(params);
443- } else if (this ->rdmft_loop_layer_ == 2 ) {
444- if (this ->orbital_solver != nullptr ) {
445- // this->orbital_solver->SetParams(params);
446- }
447- if (this ->occupation_solver != nullptr ) {
448- // this->occupation_solver->SetParams(params);
449- }
450- }
452+ // // Set parameters for different solvers based on optimization route
453+ // if (this->rdmft_loop_layer_ == 1 && this->solver != nullptr) {
454+ // // this->solver->SetParams(params);
455+ // } else if (this->rdmft_loop_layer_ == 2) {
456+ // if (this->orbital_solver != nullptr) {
457+ // // this->orbital_solver->SetParams(params);
458+ // }
459+ // if (this->occupation_solver != nullptr) {
460+ // // this->occupation_solver->SetParams(params);
461+ // }
462+ // }
451463}
452464
453465// ===============================================
@@ -709,70 +721,104 @@ void ESolver_RDMFT_LCAO<TK, TR>::setup_initial_guess(UnitCell& ucell, const Inpu
709721 ModuleBase::TITLE (this ->classname , " setup_initial_guess" );
710722
711723 X = this ->prob ->GetDomain () -> RandominManifold ();
724+ // this->X.Print("Initial variable for DirectMin:");
712725
713726 ModuleESolver::ESolver_KS_LCAO<TK, TR> * p_esolver = new ESolver_KS_LCAO<TK, TR>();
714727 p_esolver->before_all_runners (ucell, inp);
715728 p_esolver->runner (ucell, 0 );
716729
717730 int nk = p_esolver->get_nk ();
718- int ncol = p_esolver->get_nrow (); // note the ncol in KS solver is nrow in RDMFT problem, nbasis
719- int nrow = p_esolver->get_ncol (); // note the nrow in KS solver is ncol in RDMFT problem, nbands
731+ int nbasis = p_esolver->get_ncol (); // note the ncol in KS solver, nbasis
732+ int nbands= inp.nbands ; // note the nrow in KS solver, nbands
733+
734+ // /*
735+ std::cout << " X.GetElement(0).Getrow(): " << X.GetElement (0 ).Getrow () << std::endl;
736+ std::cout << " X.GetElement(0).Getcol(): " << X.GetElement (0 ).Getcol () << std::endl;
737+ std::cout << " X.Getnumofelements(): " << X.Getnumofelements () << std::endl;
738+
739+ std::cout << " nk from KS solver: " << nk << std::endl;
740+ std::cout << " nbasis from KS solver: " << nbasis << std::endl;
741+ std::cout << " nbands from KS solver: " << nbands << std::endl;
742+ // */
720743
721744 if ( nk != X.Getnumofelements () / 2 )
722745 {
723746 std::cout << " nk from KS solver: " << nk << " , nk from RDMFT problem: " << X.Getnumofelements () / 2 << std::endl;
724747 ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" , " The nk points from KS solver and RDMFT problem do not match." );
725748 }
726- if ( ncol != X.GetElement (0 ).Getcol ())
749+ if ( nbasis != X.GetElement (0 ).Getrow ())
727750 {
728- std::cout << " ncol from KS solver: " << ncol << " , ncol from RDMFT problem: " << X.GetElement (0 ).Getcol ()<< std::endl;
751+ std::cout << " nbasis from KS solver: " << nbasis << " , nbasis from RDMFT problem: " << X.GetElement (0 ).Getrow ()<< std::endl;
752+ ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" , " The number of nbasis from KS solver and RDMFT problem do not match." );
753+ }
754+ if ( nbands != X.GetElement (0 ).Getcol ())
755+ {
756+ std::cout << " nbands from KS solver: " << nbands << " , nbands from RDMFT problem: " << X.GetElement (0 ).Getcol ()<< std::endl;
729757 ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" , " The number of bands from KS solver and RDMFT problem do not match." );
730758 }
731- if ( nrow != X.GetElement (0 ).Getrow ())
732- {
733- std::cout << " nrow from KS solver: " << nrow << " , nrow from RDMFT problem: " << X.GetElement (0 ).Getrow ()<< std::endl;
734- ModuleBase::WARNING_QUIT (" ESolver_RDMFT_LCAO" , " The number of rows from KS solver and RDMFT problem do not match." );
735- }
736-
737- // seeded = seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<double, double>*>(p_esolver),
738- // dynamic_cast<RDMFT_LCAO<double, double>*>(this->prob),
739- // initial_variable)
740- // || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, double>*>(p_esolver),
741- // dynamic_cast<RDMFT_LCAO<std::complex<double>, double>*>(this->prob),
742- // initial_variable)
743- // || seed_variable_from_ks_solver(dynamic_cast<ESolver_KS_LCAO<std::complex<double>, std::complex<double>>*>(p_esolver),
744- // dynamic_cast<RDMFT_LCAO<std::complex<double>, std::complex<double>>*>(this->prob),
745- // initial_variable);
746- // }
747- // catch (const std::exception& ex)
748- // {
749- // seed_error = ex.what();
750- // }
751-
752- // ModuleESolver::clean_esolver(p_esolver, false);
753-
754- // if (!seeded)
755- // {
756- // if (!seed_error.empty())
757- // {
758- // const std::string message = "Unable to seed DirectMin variable from KS solver: " + seed_error;
759- // ModuleBase::WARNING("ESolver_RDMFT_LCAO", message.c_str());
760- // }
761- // else
762- // {
763- // ModuleBase::WARNING("ESolver_RDMFT_LCAO", "Unable to seed DirectMin variable from KS solver; using random initialization");
764- // }
765- // }
766- // }
767- // else
768- // {
769- // ModuleBase::WARNING_QUIT("ESolver_RDMFT_LCAO", "Only 'ks' is supported for directmin_init_method currently");
770- // }
759+
760+ const int occ_offset = nk;
761+
762+ // set up occupation part
763+ for (int ik = 0 ; ik < nk; ++ik)
764+ {
765+ Element& occ_elem = X.GetElement (occ_offset + ik);
766+ if (occ_elem.Getrow () != nbands || occ_elem.Getcol () != 1 || occ_elem.Getiscomplex ())
767+ {
768+ throw std::runtime_error (" KSStateToVariable encountered an occupation component with unexpected shape" );
769+ }
770+
771+ double * occ_buffer = occ_elem.ObtainWriteEntireData ();
772+ for (int ib = 0 ; ib < nbands; ++ib)
773+ {
774+ K_Vectors & kv = p_esolver->get_kv ();
775+ occ_buffer[ib] = p_esolver->get_pelec ()->wg (ik, ib) / kv.wk [ik];
776+ std::cout << " Initial occupation for k-point " << ik << " , band " << ib << " : " << occ_buffer[ib] << std::endl;
777+ }
778+
779+ }
780+
781+ // set up wavefunction part
782+ for (int ik = 0 ; ik < nk; ++ik)
783+ {
784+ Element& wf_elem = X.GetElement (ik);
785+
786+ double * raw_data = wf_elem.ObtainWriteEntireData ();
787+ const bool is_complex = wf_elem.Getiscomplex ();
788+ const int element_length = wf_elem.Getlength ();
789+ std::fill (raw_data, raw_data + element_length, 0.0 );
790+ std::complex <double >* complex_buffer = reinterpret_cast <std::complex <double >*>(raw_data);
791+
792+ psi::Psi<TK> * wfc = p_esolver->get_psi ();
793+ wfc->fix_k (ik);
794+
795+ for (int ib_local = 0 ; ib_local < nbands; ++ib_local)
796+ {
797+ wfc->fix_b (ib_local);
798+
799+ for (int ir_local = 0 ; ir_local < nbasis; ++ir_local)
800+ {
801+
802+ const int linear_index = ir_local + nbasis * ib_local;
803+ const TK value = (*wfc)(ir_local);
804+
805+ if (is_complex)
806+ {
807+ complex_buffer[linear_index] = std::complex <double >(std::real (value), std::imag (value));
808+ }
809+ else
810+ {
811+ raw_data[linear_index] = static_cast <double >(std::real (value));
812+ }
813+ }
814+ }
815+ }
816+
771817
772818 // this->X = std::move(initial_variable);
773- this ->X .Print (" Initial variable for DirectMin:" );
819+ if (true ) // for debug
820+ this ->X .Print (" Initial variable for DirectMin:" );
774821 // }
775- // Implementation of initial guess setup
776822}
777823
778824
0 commit comments