Skip to content

Commit 0551689

Browse files
committed
fix: reproducing gs-dm and gs-force
1 parent 67b435d commit 0551689

File tree

7 files changed

+97
-53
lines changed

7 files changed

+97
-53
lines changed

source/module_lr/Grad/esolver_lr_grad.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ ct::Tensor LR::ESolver_LR<T, TR>::solve_zvector_eqation(const int ispin)
120120
template<typename T, typename TR>
121121
std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin)
122122
{
123-
if (PARAM.inp.test_force) { this->test_force(ispin); }
123+
if (PARAM.inp.test_force && ispin == 0) { this->test_force(); }
124124

125125
const ct::Tensor& Z = this->solve_zvector_eqation(ispin);
126126

@@ -130,7 +130,7 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
130130

131131
// calculate the force (the partial gradient of Lagrangian)
132132
LR_Force<T> lr_force(this->ucell, this->kv.kvec_d, this->paraMat_, *this->pw_rho, this->locpp, this->sf, this->gd, this->gint_, this->two_center_bundle_);
133-
GlobalV::ofs_running << "Start to calculate excited-state force of " << this->spin_types[ispin] << std::endl;
133+
GlobalV::ofs_running << "Start to calculate excited-state force of " << this->spin_types[ispin] << std::endl;
134134
// ground state dm for currrent spin (only for test the correctness of the force)
135135
// elecstate::DensityMatrix<T, T> dm_gs(this->paraMat_, 1, this->kv.kvec_d, this->nk);
136136

@@ -199,7 +199,7 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
199199
ModuleBase::matrix force_hxc_dmtrans = lr_force.cal_force_hxc_dmtrans(dm_trans_real, *this->pot[ispin]);
200200
std::cout << "Force (Hxc-DMTrans term) of state " << istate << ": " << std::endl;
201201
LR_Util::print_value(force_hxc_dmtrans.c, ucell.nat, 3);
202-
ModuleBase::matrix force_hamiltgs_relaxed_diff = lr_force.cal_force_hamilt_gs_dm_relaxed_diff(relaxed_diff_dm_real, *pot_gs);
202+
ModuleBase::matrix force_hamiltgs_relaxed_diff = lr_force.cal_force_hamilt_gs_dm_relaxed_diff(relaxed_diff_dm_real, *pot_gs, /*with_ewald=*/false);
203203
std::cout << "Force (GS-(T+Z) term) of state " << istate << ": " << std::endl;
204204
LR_Util::print_value(force_hamiltgs_relaxed_diff.c, ucell.nat, 3);
205205
ModuleBase::matrix force_overlap_edm = lr_force.cal_force_overlap_edm(edm_real); // "-" sign has been included in the force factor
@@ -215,21 +215,22 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
215215
}
216216

217217
template<typename T, typename TR>
218-
void LR::ESolver_LR<T, TR>::test_force(const int ispin)
218+
void LR::ESolver_LR<T, TR>::test_force()
219219
{
220220
LR_Force<T> lr_force(this->ucell, this->kv.kvec_d, this->paraMat_, *this->pw_rho, this->locpp, this->sf, this->gd, this->gint_, this->two_center_bundle_);
221221

222222
elecstate::DensityMatrix<T, double> dm_gs(&this->paraMat_, this->nspin, this->kv.kvec_d, this->nk); //DX
223-
elecstate::cal_dm_psi(&this->paraMat_, this->wg_ks, *this->psi_ks, dm_gs);
224-
LR_Util::initialize_DMR(dm_gs, this->paraMat_, this->ucell, this->gd, this->orb_cutoff_);
223+
elecstate::cal_dm_psi(&this->paraMat_all_, this->wg_ks_all, *this->psi_ks_all, dm_gs); // nbands is important here
224+
LR_Util::initialize_DMR(dm_gs, this->paraMat_, this->ucell, this->gd, this->orb_cutoff_); // nbands is not important here
225225
dm_gs.cal_DMR();
226-
226+
LR_Util::print_DMR(dm_gs, this->ucell.nat, "DM(R) of ground state");
227227
///========================== test 1: reproduce the force of ground state =========================
228228
// energy density matrix of the ground state
229229
elecstate::DensityMatrix<T, double> edm_gs(&this->paraMat_, this->nspin, this->kv.kvec_d, this->nk); //DX
230-
ModuleBase::matrix wg_ekb_ks(nspin, nbands);
231-
std::transform(this->wg_ks.c, this->wg_ks.c + nspin * nbands, this->eig_ks.c, wg_ekb_ks.c, std::multiplies<double>());
232-
elecstate::cal_dm_psi(&this->paraMat_, wg_ekb_ks, *this->psi_ks, edm_gs);
230+
ModuleBase::matrix wg_ekb_ks_all(nspin, PARAM.inp.nbands);
231+
std::transform(this->wg_ks_all.c, this->wg_ks_all.c + nspin * PARAM.inp.nbands,
232+
this->eig_ks_all.c, wg_ekb_ks_all.c, std::multiplies<double>());
233+
elecstate::cal_dm_psi(&this->paraMat_all_, wg_ekb_ks_all, *this->psi_ks_all, edm_gs);
233234
LR_Util::initialize_DMR(edm_gs, this->paraMat_, this->ucell, this->gd, this->orb_cutoff_);
234235
edm_gs.cal_DMR();
235236
// ground-state force
@@ -239,7 +240,7 @@ void LR::ESolver_LR<T, TR>::test_force(const int ispin)
239240
///========================== test 2: reproduce the DX Hartree term =========================
240241
ModuleBase::matrix f_hxc_potgs = lr_force.reproduce_force_gs_loc(dm_gs, *this->pot_gs_hartree);
241242
ModuleIO::print_force(GlobalV::ofs_running, this->ucell, "GS Hartree force calculated by 'cal_pulay_fs' from potential (eV/Angstrom)", f_hxc_potgs, false);
242-
ModuleBase::matrix f_hxc_potlr = lr_force.cal_force_hxc_dmtrans(dm_gs, *this->pot[ispin]);
243+
ModuleBase::matrix f_hxc_potlr = lr_force.cal_force_hxc_dmtrans(dm_gs, *this->pot[0]);
243244
ModuleIO::print_force(GlobalV::ofs_running, this->ucell, "2* GS Hxc force calculated by 'LR_Force' from kernel (eV/Angstrom)", f_hxc_potlr * 2, false);
244245
// 2 for spin in f->v. Spin in v->f is already multiplied in the singlet Hartree factor 2.
245246
/// ======================================= END test 2 =========================================

source/module_lr/Grad/force/lr_force.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ namespace LR
2626
}
2727

2828
template<typename TK>
29-
ModuleBase::matrix LR_Force<TK>::cal_force_hamilt_gs_dm_relaxed_diff(const elecstate::DensityMatrix<TK, double>& relax_diff_dm, const elecstate::Potential& pot_gs)
29+
ModuleBase::matrix LR_Force<TK>::cal_force_hamilt_gs_dm_relaxed_diff(const elecstate::DensityMatrix<TK, double>& relax_diff_dm,
30+
const elecstate::Potential& pot_gs, const bool with_ewald)
3031
{
3132
const Charge chr_diff_relaxed = dm_to_charge(relax_diff_dm);
3233

3334
// 1. local pp (Hellmann-Feynman)(fvl_dvl) + ewald + core correction (+ self-consistent charge)
34-
ModuleBase::matrix f_pw = ForcePWTerms<double>()(this->ucell_, chr_diff_relaxed, this->rhopw_, this->locpp_, this->sf_, /*with_ewald=*/false);
35+
ModuleBase::matrix f_pw = ForcePWTerms<double>()(this->ucell_, chr_diff_relaxed, this->rhopw_, this->locpp_, this->sf_, with_ewald);
3536

3637
// 2. nonlocal pp (Hellmann-Feynman + Pulay)
3738
ModuleBase::matrix fvnl = cal_force_nonlocal(this->ucell_, this->kvec_d_, this->gd_, this->two_center_bundle_, relax_diff_dm);

source/module_lr/Grad/force/lr_force.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ namespace LR
2525
}
2626

2727
/// 1. $Tr[H_{GS}^x * (T+D^Z)]$, where GS=groud state and $(T+D^Z)$ is the relaxed difference density matrix
28-
ModuleBase::matrix cal_force_hamilt_gs_dm_relaxed_diff(const elecstate::DensityMatrix<TK, double>& relaxed_diff_dm, const elecstate::Potential& pot_gs);
28+
ModuleBase::matrix cal_force_hamilt_gs_dm_relaxed_diff(const elecstate::DensityMatrix<TK, double>& relaxed_diff_dm,
29+
const elecstate::Potential& pot_gs, const bool with_ewald = true);
2930

3031
/// 2. $Tr[S^x * (EDM)]
3132
ModuleBase::matrix cal_force_overlap_edm(const elecstate::DensityMatrix<TK, double>& edm);

source/module_lr/Grad/multipliers/hamilt_zeq_right.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ namespace LR
103103
for (int ib = 0;ib < nband;++ib)
104104
{
105105
const int offset = ib * ld_psi;
106-
this->cal_dm_trans(0, psi + offset); // transition density matrix
106+
// this->cal_dm_trans(0, psi + offset); // transition density matrix, only for test
107107
this->cal_dm_diff(0, psi + offset); // difference density matrix
108108
hamilt::Operator<T>* node(this->ops);
109109
while (node != nullptr)

source/module_lr/esolver_lrtd_lcao.cpp

Lines changed: 64 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,14 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol
181181

182182
this->set_dimension();
183183

184-
// setup_wd_division is not need to be covered in #ifdef __MPI, see its implementation
184+
// setup_2d_division is not need to be covered in #ifdef __MPI, see its implementation
185185
LR_Util::setup_2d_division(this->paraMat_, 1, this->nbasis, this->nbasis);
186-
187-
this->paraMat_.atom_begin_row = std::move(ks_sol.pv.atom_begin_row);
188-
this->paraMat_.atom_begin_col = std::move(ks_sol.pv.atom_begin_col);
189-
this->paraMat_.iat2iwt_ = ucell.get_iat2iwt();
186+
this->set_parallel_orbitals_band(this->paraMat_, this->nbands);
187+
if (PARAM.inp.cal_force)
188+
{
189+
LR_Util::setup_2d_division(this->paraMat_all_, 1, this->nbasis, this->nbasis);
190+
this->set_parallel_orbitals_band(this->paraMat_all_, PARAM.inp.nbands);
191+
}
190192

191193
LR_Util::setup_2d_division(this->paraC_, 1, this->nbasis, this->nbands
192194
#ifdef __MPI
@@ -195,32 +197,42 @@ LR::ESolver_LR<T, TR>::ESolver_LR(ModuleESolver::ESolver_KS_LCAO<T, TR>&& ks_sol
195197
);
196198
auto move_gs = [&, this]() -> void // move the ground state info
197199
{
198-
this->psi_ks = ks_sol.psi;
200+
this->psi_ks_all = ks_sol.psi;
199201
ks_sol.psi = nullptr;
200202
//only need the eigenvalues. the 'elecstates' of excited states is different from ground state.
201-
this->eig_ks = std::move(ks_sol.pelec->ekb);
203+
this->eig_ks_all = std::move(ks_sol.pelec->ekb);
202204
};
205+
move_gs();
206+
// allocate psi_ks and eig_ks in the [nocc, nvirt] window
203207
#ifdef __MPI
204-
if (this->nbands == PARAM.inp.nbands) { move_gs(); }
205-
else // copy the part of ground state info according to paraC_
208+
this->psi_ks = new psi::Psi<T>(this->kv.get_nks(),
209+
this->paraC_.get_col_size(),
210+
this->paraC_.get_row_size(),
211+
this->kv.ngk,
212+
true);
213+
#else
214+
this->psi_ks = new psi::Psi<T>(this->kv.get_nks(), this->nbands, this->nbasis, this->kv.ngk, true);
215+
#endif
216+
this->eig_ks.create(this->kv.get_nks(), this->nbands);
217+
const int start_band = this->nocc_max - *std::max_element(nocc.begin(), nocc.end());
218+
219+
for (int ik = 0;ik < this->kv.get_nks();++ik)
206220
{
207-
this->psi_ks = new psi::Psi<T>(this->kv.get_nks(),
208-
this->paraC_.get_col_size(),
209-
this->paraC_.get_row_size(),
210-
this->kv.ngk,
211-
true);
212-
this->eig_ks.create(this->kv.get_nks(), this->nbands);
213-
const int start_band = this->nocc_max - *std::max_element(nocc.begin(), nocc.end());
214-
for (int ik = 0;ik < this->kv.get_nks();++ik)
221+
// copy the KS orbitals in the [nocc, nvirt] window
222+
#ifdef __MPI
223+
Cpxgemr2d(this->nbasis, this->nbands, &(*this->psi_ks_all)(ik, 0, 0), 1, start_band + 1, ks_sol.pv.desc_wfc,
224+
&(*this->psi_ks)(ik, 0, 0), 1, 1, this->paraC_.desc, this->paraC_.blacs_ctxt);
225+
#else
226+
for (int ib = 0;ib < this->nbands;++ib)
215227
{
216-
Cpxgemr2d(this->nbasis, this->nbands, &(*ks_sol.psi)(ik, 0, 0), 1, start_band + 1, ks_sol.pv.desc_wfc,
217-
&(*this->psi_ks)(ik, 0, 0), 1, 1, this->paraC_.desc, this->paraC_.blacs_ctxt);
218-
for (int ib = 0;ib < this->nbands;++ib) { this->eig_ks(ik, ib) = ks_sol.pelec->ekb(ik, start_band + ib); }
228+
auto* start = &(*this->psi_ks_all)(ik, start_band + ib, 0);
229+
auto* to = &(*this->psi_ks)(ik, ib, 0);
219230
}
220-
}
221-
#else
222-
move_gs();
223231
#endif
232+
// copy the KS bands in the [nocc, nvirt] window
233+
for (int ib = 0;ib < this->nbands;++ib) { this->eig_ks(ik, ib) = this->eig_ks_all(ik, start_band + ib); }
234+
}
235+
224236
if (nspin == 2)
225237
{
226238
this->nupdown = cal_nupdown_form_occ(ks_sol.pelec->wg);
@@ -309,14 +321,12 @@ LR::ESolver_LR<T, TR>::ESolver_LR(const Input_para& inp, UnitCell& ucell) : inpu
309321
this->set_dimension();
310322
// setup 2d-block distribution for AO-matrix and KS wfc
311323
LR_Util::setup_2d_division(this->paraMat_, 1, this->nbasis, this->nbasis);
312-
#ifdef __MPI
313-
this->paraMat_.set_desc_wfc_Eij(this->nbasis, this->nbands, paraMat_.get_row_size());
314-
int err = this->paraMat_.set_nloc_wfc_Eij(this->nbands, GlobalV::ofs_running, GlobalV::ofs_warning);
315-
if (input.ri_hartree_benchmark != "aims") { this->paraMat_.set_atomic_trace(ucell.get_iat2iwt(), ucell.nat, this->nbasis); }
316-
#else
317-
this->paraMat_.nrow_bands = this->nbasis;
318-
this->paraMat_.ncol_bands = this->nbands;
319-
#endif
324+
this->set_parallel_orbitals_band(this->paraMat_, this->nbands);
325+
if (PARAM.inp.cal_force)
326+
{
327+
LR_Util::setup_2d_division(this->paraMat_all_, 1, this->nbasis, this->nbasis);
328+
this->set_parallel_orbitals_band(this->paraMat_all_, PARAM.inp.nbands);
329+
}
320330

321331
// read the ground state info
322332
// now ModuleIO::read_wfc_nao needs `Parallel_Orbitals` and can only read all the bands
@@ -594,6 +604,18 @@ void LR::ESolver_LR<T, TR>::after_all_runners(UnitCell& ucell)
594604
if (PARAM.inp.cal_force) { this->cal_force(is); }
595605
}
596606
}
607+
template<typename T, typename TR>
608+
void LR::ESolver_LR<T, TR>::set_parallel_orbitals_band(Parallel_Orbitals& pmat, const int nbands_in)
609+
{
610+
#ifdef __MPI
611+
pmat.set_desc_wfc_Eij(this->nbasis, nbands_in, pmat.get_row_size());
612+
int err = pmat.set_nloc_wfc_Eij(nbands_in, GlobalV::ofs_running, GlobalV::ofs_warning);
613+
if (input.ri_hartree_benchmark != "aims") { pmat.set_atomic_trace(ucell.get_iat2iwt(), ucell.nat, this->nbasis); }
614+
#else
615+
pmat.nrow_bands = this->nbasis;
616+
pmat.ncol_bands = nbands_in;
617+
#endif
618+
}
597619

598620
template<typename T, typename TR>
599621
void LR::ESolver_LR<T, TR>::setup_eigenvectors_X()
@@ -717,8 +739,17 @@ void LR::ESolver_LR<T, TR>::read_ks_wfc()
717739
/*skip_bands=*/this->nocc_max - this->nocc_in)) {
718740
ModuleBase::WARNING_QUIT("ESolver_LR", "read ground-state wavefunction failed.");
719741
}
720-
this->eig_ks = std::move(this->pelec->ekb);
721-
this->wg_ks = std::move(this->pelec->wg);
742+
743+
if (PARAM.inp.cal_force)
744+
{ // allocate psi_ks_all and eig_ks_all to read all the bands
745+
this->psi_ks_all = new psi::Psi<T>(this->kv.get_nks(), paraMat_all_.ncol_bands, paraMat_all_.get_row_size(), this->kv.ngk, true);
746+
this->eig_ks_all.create(this->kv.get_nks(), PARAM.inp.nbands);
747+
this->wg_ks_all.create(this->kv.get_nks(), PARAM.inp.nbands);
748+
if (!ModuleIO::read_wfc_nao(PARAM.globalv.global_readin_dir, paraMat_all_, *this->psi_ks_all, this->wg_ks_all, this->eig_ks_all,/*skip_bands=*/0))
749+
{
750+
GlobalV::ofs_running << " Read in all the KS wavefunctions for force calculation. " << std::endl;
751+
}
752+
}
722753
}
723754

724755
template<typename T, typename TR>

source/module_lr/esolver_lrtd_lcao.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ namespace LR
3535
ESolver_LR(const Input_para& inp, UnitCell& ucell);
3636
~ESolver_LR() {
3737
delete this->psi_ks;
38+
delete this->psi_ks_all;
3839
}
3940

4041
///input: input, call, basis(LCAO), psi(ground state), elecstate
@@ -61,17 +62,22 @@ namespace LR
6162
// ground state info
6263

6364
/// @brief ground state wave function
64-
psi::Psi<T>* psi_ks = nullptr;
65+
psi::Psi<T>* psi_ks = nullptr; ///< KS orbitals used in the [nocc+nvirt] window
66+
psi::Psi<T>* psi_ks_all = nullptr; ///< all KS orbitals, read from the file, or moved from ESolver_FP::pelec.psi
6567

6668
/// @brief ground state bands, read from the file, or moved from ESolver_FP::pelec.ekb
67-
ModuleBase::matrix eig_ks;///< energy of ground state
69+
ModuleBase::matrix eig_ks;///< ground state eigenvalues in the [nocc+nvirt] window
70+
ModuleBase::matrix eig_ks_all; ///< all eigenvalues of ground state, read from the file, or moved from ESolver_FP::pelec.ekb
71+
ModuleBase::matrix wg_ks; /// occupation numbers of ground state in the [nocc+nvirt] window
72+
ModuleBase::matrix wg_ks_all; /// occupation number of all bands of ground state
73+
6874

6975
// @brief only needed for force calculation
7076
std::unique_ptr<elecstate::Potential> pot_gs;
7177
std::unique_ptr<elecstate::Potential> pot_gs_hartree; /// ground-state Hartree potential, only used for test_force
7278
double etxc_gs = 0.;
7379
double vtxc_gs = 0.;
74-
ModuleBase::matrix wg_ks; /// occupation number of ground state
80+
7581
std::shared_ptr<PotHxcLR> pot_hxc_gs; /// used in lr-grad, in the ground-state Hxc gradient term coming from dF/dC
7682

7783
/// @brief Excited state wavefunction (locc, lvirt are local size of nocc and nvirt in each process)
@@ -112,6 +118,7 @@ namespace LR
112118
std::vector<Parallel_2D> paraX_;
113119
/// @brief variables for parallel distribution of matrix in AO representation
114120
Parallel_Orbitals paraMat_;
121+
Parallel_Orbitals paraMat_all_; // for the parallelized size of the KS orbitals
115122

116123
TwoCenterBundle two_center_bundle_;
117124

@@ -133,11 +140,14 @@ namespace LR
133140
/// reset nocc, nvirt, npairs after read ground-state wavefunction when nspin=2
134141
void reset_dim_spin2();
135142

143+
/// setup Parallel_Orbitals info. beyond Parallel_2D
144+
void set_parallel_orbitals_band(Parallel_Orbitals& p, const int nbands_in);
145+
136146
///========================== for gradient calculation =========================
137147
void init_pot_groundstate(const Charge& chg_gs);
138148
ct::Tensor solve_zvector_eqation(const int ispin);
139149
std::vector<ModuleBase::matrix> cal_force(const int ispin);
140-
void test_force(const int ispin); // test: reproduce the force of ground state
150+
void test_force(); // test: reproduce the force of ground state
141151

142152
#ifdef __EXX
143153
/// Tdata of Exx_LRI is same as T, for the reason, see operator_lr_exx.h

source/module_lr/utils/lr_util_hcontainer.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ namespace LR_Util
4242
std::cout << label << "\n";
4343
int is = 0;
4444
for (auto& dr : DMR.get_DMR_vector())
45-
print_HR(*dr, nat, "DMR[" + std::to_string(is++) + "]", threshold);
45+
print_HR(*dr, nat, "DMR[ispin=s" + std::to_string(is++) + "]", threshold);
4646
}
4747

4848
template<typename T>

0 commit comments

Comments
 (0)