@@ -181,12 +181,14 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol
181181
182182 this ->set_dimension ();
183183
184- // setup_wd_division is not need to be covered in #ifdef __MPI, see its implementation
184+ // setup_2d_division is not need to be covered in #ifdef __MPI, see its implementation
185185 LR_Util::setup_2d_division (this ->paraMat_ , 1 , this ->nbasis , this ->nbasis );
186-
187- this ->paraMat_ .atom_begin_row = std::move (ks_sol.pv .atom_begin_row );
188- this ->paraMat_ .atom_begin_col = std::move (ks_sol.pv .atom_begin_col );
189- this ->paraMat_ .iat2iwt_ = ucell.get_iat2iwt ();
186+ this ->set_parallel_orbitals_band (this ->paraMat_ , this ->nbands );
187+ if (PARAM.inp .cal_force )
188+ {
189+ LR_Util::setup_2d_division (this ->paraMat_all_ , 1 , this ->nbasis , this ->nbasis );
190+ this ->set_parallel_orbitals_band (this ->paraMat_all_ , PARAM.inp .nbands );
191+ }
190192
191193 LR_Util::setup_2d_division (this ->paraC_ , 1 , this ->nbasis , this ->nbands
192194#ifdef __MPI
@@ -195,32 +197,42 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol
195197 );
196198 auto move_gs = [&, this ]() -> void // move the ground state info
197199 {
198- this ->psi_ks = ks_sol.psi ;
200+ this ->psi_ks_all = ks_sol.psi ;
199201 ks_sol.psi = nullptr ;
200202 // only need the eigenvalues. the 'elecstates' of excited states is different from ground state.
201- this ->eig_ks = std::move (ks_sol.pelec ->ekb );
203+ this ->eig_ks_all = std::move (ks_sol.pelec ->ekb );
202204 };
205+ move_gs ();
206+ // allocate psi_ks and eig_ks in the [nocc, nvirt] window
203207#ifdef __MPI
204- if (this ->nbands == PARAM.inp .nbands ) { move_gs (); }
205- else // copy the part of ground state info according to paraC_
208+ this ->psi_ks = new psi::Psi<T>(this ->kv .get_nks (),
209+ this ->paraC_ .get_col_size (),
210+ this ->paraC_ .get_row_size (),
211+ this ->kv .ngk ,
212+ true );
213+ #else
214+ this ->psi_ks = new psi::Psi<T>(this ->kv .get_nks (), this ->nbands , this ->nbasis , this ->kv .ngk , true );
215+ #endif
216+ this ->eig_ks .create (this ->kv .get_nks (), this ->nbands );
217+ const int start_band = this ->nocc_max - *std::max_element (nocc.begin (), nocc.end ());
218+
219+ for (int ik = 0 ;ik < this ->kv .get_nks ();++ik)
206220 {
207- this ->psi_ks = new psi::Psi<T>(this ->kv .get_nks (),
208- this ->paraC_ .get_col_size (),
209- this ->paraC_ .get_row_size (),
210- this ->kv .ngk ,
211- true );
212- this ->eig_ks .create (this ->kv .get_nks (), this ->nbands );
213- const int start_band = this ->nocc_max - *std::max_element (nocc.begin (), nocc.end ());
214- for (int ik = 0 ;ik < this ->kv .get_nks ();++ik)
221+ // copy the KS orbitals in the [nocc, nvirt] window
222+ #ifdef __MPI
223+ Cpxgemr2d (this ->nbasis , this ->nbands , &(*this ->psi_ks_all )(ik, 0 , 0 ), 1 , start_band + 1 , ks_sol.pv .desc_wfc ,
224+ &(*this ->psi_ks )(ik, 0 , 0 ), 1 , 1 , this ->paraC_ .desc , this ->paraC_ .blacs_ctxt );
225+ #else
226+ for (int ib = 0 ;ib < this ->nbands ;++ib)
215227 {
216- Cpxgemr2d (this ->nbasis , this ->nbands , &(*ks_sol.psi )(ik, 0 , 0 ), 1 , start_band + 1 , ks_sol.pv .desc_wfc ,
217- &(*this ->psi_ks )(ik, 0 , 0 ), 1 , 1 , this ->paraC_ .desc , this ->paraC_ .blacs_ctxt );
218- for (int ib = 0 ;ib < this ->nbands ;++ib) { this ->eig_ks (ik, ib) = ks_sol.pelec ->ekb (ik, start_band + ib); }
228+ auto * start = &(*this ->psi_ks_all )(ik, start_band + ib, 0 );
229+ auto * to = &(*this ->psi_ks )(ik, ib, 0 );
219230 }
220- }
221- #else
222- move_gs ();
223231#endif
232+ // copy the KS bands in the [nocc, nvirt] window
233+ for (int ib = 0 ;ib < this ->nbands ;++ib) { this ->eig_ks (ik, ib) = this ->eig_ks_all (ik, start_band + ib); }
234+ }
235+
224236 if (nspin == 2 )
225237 {
226238 this ->nupdown = cal_nupdown_form_occ (ks_sol.pelec ->wg );
@@ -309,14 +321,12 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
309321 this ->set_dimension ();
310322 // setup 2d-block distribution for AO-matrix and KS wfc
311323 LR_Util::setup_2d_division (this ->paraMat_ , 1 , this ->nbasis , this ->nbasis );
312- #ifdef __MPI
313- this ->paraMat_ .set_desc_wfc_Eij (this ->nbasis , this ->nbands , paraMat_.get_row_size ());
314- int err = this ->paraMat_ .set_nloc_wfc_Eij (this ->nbands , GlobalV::ofs_running, GlobalV::ofs_warning);
315- if (input.ri_hartree_benchmark != " aims" ) { this ->paraMat_ .set_atomic_trace (ucell.get_iat2iwt (), ucell.nat , this ->nbasis ); }
316- #else
317- this ->paraMat_ .nrow_bands = this ->nbasis ;
318- this ->paraMat_ .ncol_bands = this ->nbands ;
319- #endif
324+ this ->set_parallel_orbitals_band (this ->paraMat_ , this ->nbands );
325+ if (PARAM.inp .cal_force )
326+ {
327+ LR_Util::setup_2d_division (this ->paraMat_all_ , 1 , this ->nbasis , this ->nbasis );
328+ this ->set_parallel_orbitals_band (this ->paraMat_all_ , PARAM.inp .nbands );
329+ }
320330
321331 // read the ground state info
322332 // now ModuleIO::read_wfc_nao needs `Parallel_Orbitals` and can only read all the bands
@@ -594,6 +604,18 @@ void LR::ESolver_LR<T, TR>::after_all_runners(UnitCell& ucell)
594604 if (PARAM.inp .cal_force ) { this ->cal_force (is); }
595605 }
596606}
607+ template <typename T, typename TR>
608+ void LR::ESolver_LR<T, TR>::set_parallel_orbitals_band(Parallel_Orbitals& pmat, const int nbands_in)
609+ {
610+ #ifdef __MPI
611+ pmat.set_desc_wfc_Eij (this ->nbasis , nbands_in, pmat.get_row_size ());
612+ int err = pmat.set_nloc_wfc_Eij (nbands_in, GlobalV::ofs_running, GlobalV::ofs_warning);
613+ if (input.ri_hartree_benchmark != " aims" ) { pmat.set_atomic_trace (ucell.get_iat2iwt (), ucell.nat , this ->nbasis ); }
614+ #else
615+ pmat.nrow_bands = this ->nbasis ;
616+ pmat.ncol_bands = nbands_in;
617+ #endif
618+ }
597619
598620template <typename T, typename TR>
599621void LR::ESolver_LR<T, TR>::setup_eigenvectors_X()
@@ -717,8 +739,17 @@ void LR::ESolver_LR<T, TR>::read_ks_wfc()
717739 /* skip_bands=*/ this ->nocc_max - this ->nocc_in )) {
718740 ModuleBase::WARNING_QUIT (" ESolver_LR" , " read ground-state wavefunction failed." );
719741 }
720- this ->eig_ks = std::move (this ->pelec ->ekb );
721- this ->wg_ks = std::move (this ->pelec ->wg );
742+
743+ if (PARAM.inp .cal_force )
744+ { // allocate psi_ks_all and eig_ks_all to read all the bands
745+ this ->psi_ks_all = new psi::Psi<T>(this ->kv .get_nks (), paraMat_all_.ncol_bands , paraMat_all_.get_row_size (), this ->kv .ngk , true );
746+ this ->eig_ks_all .create (this ->kv .get_nks (), PARAM.inp .nbands );
747+ this ->wg_ks_all .create (this ->kv .get_nks (), PARAM.inp .nbands );
748+ if (!ModuleIO::read_wfc_nao (PARAM.globalv .global_readin_dir , paraMat_all_, *this ->psi_ks_all , this ->wg_ks_all , this ->eig_ks_all ,/* skip_bands=*/ 0 ))
749+ {
750+ GlobalV::ofs_running << " Read in all the KS wavefunctions for force calculation. " << std::endl;
751+ }
752+ }
722753}
723754
724755template <typename T, typename TR>
0 commit comments