Skip to content

Commit 4f1be57

Browse files
authored
Fix the repeated initial guess and use a diagonal preconditioner in LR::HSolver; fix a segfault and non-Hermiticity in multi-k op_lr_exx (#5468)
* fix the initial guess * diag-precondition * fix segfault and non-Hermiticity in op_lr_exx
1 parent 9718c41 commit 4f1be57

File tree

8 files changed

+55
-30
lines changed

8 files changed

+55
-30
lines changed

source/module_lr/dm_trans/dm_trans_parallel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ std::vector<container::Tensor> cal_dm_trans_pblas(const std::complex<double>* X_
109109
// &beta, dm_trans[isk].data<std::complex<double>>(), &i1, &i1, pmat.desc);
110110

111111
// ============== [C_virt * X * C_occ^\dagger]^T=============
112-
// ============== = [C_occ^* * X^T * C_virt^T]^T=============
112+
// ============== = [C_occ^* * X^T * C_virt^T]=============
113113
// 1. X*C_occ^\dagger
114114
char transa = 'N';
115115
char transb = 'C';

source/module_lr/esolver_lrtd_lcao.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "module_base/scalapack_connector.h"
1717
#include "module_parameter/parameter.h"
1818
#include "module_lr/ri_benchmark/ri_benchmark.h"
19+
#include "module_lr/operator_casida/operator_lr_diag.h" // for precondition
1920

2021
#ifdef __EXX
2122
template<>
@@ -421,20 +422,29 @@ void LR::ESolver_LR<T, TR>::runner(int istep, UnitCell& cell)
421422
if (GlobalV::MY_RANK == 0) { assert(nst == LR_Util::write_value(efile(label), prec, e, nst)); }
422423
assert(nst * dim == LR_Util::write_value(vfile(label), prec, v, nst, dim));
423424
};
425+
std::vector<double> precondition(this->input.lr_solver == "lapack" ? 0 : nloc_per_band, 1.0);
424426
// allocate and initialize A matrix and density matrix
425427
if (openshell)
426428
{
429+
for (int is : {0, 1})
430+
{
431+
const int offset_is = is * this->paraX_[0].get_local_size();
432+
OperatorLRDiag<double> pre_op(this->eig_ks.c + is * nk * (nocc[0] + nvirt[0]), this->paraX_[is], this->nk, this->nocc[is], this->nvirt[is]);
433+
if (input.lr_solver != "lapack") { pre_op.act(1, offset_is, 1, precondition.data() + offset_is, precondition.data() + offset_is); }
434+
}
427435
std::cout << "Solving spin-conserving excitation for open-shell system." << std::endl;
428436
HamiltULR<T> hulr(xc_kernel, nspin, this->nbasis, this->nocc, this->nvirt, this->ucell, orb_cutoff_, GlobalC::GridD, *this->psi_ks, this->eig_ks,
429437
#ifdef __EXX
430438
this->exx_lri, this->exx_info.info_global.hybrid_alpha,
431439
#endif
432440
this->gint_, this->pot, this->kv, this->paraX_, this->paraC_, this->paraMat_);
433-
LR::HSolver::solve(hulr, this->X[0].template data<T>(), nloc_per_band, nstates, this->pelec->ekb.c, this->input.lr_solver, this->input.lr_thr);
441+
LR::HSolver::solve(hulr, this->X[0].template data<T>(), nloc_per_band, nstates, this->pelec->ekb.c, this->input.lr_solver, this->input.lr_thr, precondition);
434442
if (input.out_wfc_lr) { write_states("openshell", this->pelec->ekb.c, this->X[0].template data<T>(), nloc_per_band, nstates); }
435443
}
436444
else
437445
{
446+
OperatorLRDiag<double> pre_op(this->eig_ks.c, this->paraX_[0], this->nk, this->nocc[0], this->nvirt[0]);
447+
if (input.lr_solver != "lapack") { pre_op.act(1, nloc_per_band, 1, precondition.data(), precondition.data()); }
438448
auto spin_types = std::vector<std::string>({ "singlet", "triplet" });
439449
for (int is = 0;is < nspin;++is)
440450
{
@@ -447,7 +457,7 @@ void LR::ESolver_LR<T, TR>::runner(int istep, UnitCell& cell)
447457
spin_types[is], input.ri_hartree_benchmark, (input.ri_hartree_benchmark == "aims" ? input.aims_nbasis : std::vector<int>({})));
448458
// solve the Casida equation
449459
LR::HSolver::solve(hlr, this->X[is].template data<T>(), nloc_per_band, nstates,
450-
this->pelec->ekb.c + is * nstates, this->input.lr_solver, this->input.lr_thr/*,
460+
this->pelec->ekb.c + is * nstates, this->input.lr_solver, this->input.lr_thr, precondition/*,
451461
!std::set<std::string>({ "hf", "hse" }).count(this->xc_kernel)*/); //whether the kernel is Hermitian
452462
if (input.out_wfc_lr) { write_states(spin_types[is], this->pelec->ekb.c + is * nstates, this->X[is].template data<T>(), nloc_per_band, nstates); }
453463
}
@@ -565,8 +575,9 @@ void LR::ESolver_LR<T, TR>::set_X_initial_guess()
565575
const int is_in_x = openshell ? 0 : is; // if openshell, spin-up and spin-down are put together
566576
if (px.in_this_processor(virt_global, occ_global))
567577
{
578+
const int xstart_pair = ik * px.get_local_size();
568579
const int ipair_loc = px.global2local_col(occ_global) * px.get_row_size() + px.global2local_row(virt_global);
569-
X[is_in_x].data<T>()[xstart_bs + ipair_loc] = (static_cast<T>(1.0) / static_cast<T>(nk));
580+
X[is_in_x].data<T>()[xstart_bs + xstart_pair + ipair_loc] = (static_cast<T>(1.0) / static_cast<T>(nk));
570581
}
571582
}
572583
}

source/module_lr/hsolver_lrtd.hpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ namespace LR
3232
double* eig,
3333
const std::string method,
3434
const Real<T>& diag_ethr, ///< threshold for diagonalization
35+
const std::vector<Real<T>>& precondition,
3536
const bool hermitian = true)
3637
{
3738
ModuleBase::TITLE("HSolverLR", "solve");
3839
const std::vector<std::string> spin_types = { "singlet", "triplet" };
3940
// note: if not TDA, the eigenvalues will be complex
4041
// then we will need a new constructor of DiagoDavid
4142

42-
// 1. allocate precondition and eigenvalue
43-
std::vector<Real<T>> precondition(dim);
43+
// 1. allocate eigenvalue
4444
std::vector<Real<T>> eigenvalue(nband); //nstates
4545
// 2. select the method
4646
#ifdef __MPI
@@ -67,9 +67,7 @@ namespace LR
6767
}
6868
else
6969
{
70-
// 3. set precondition and diagethr
71-
for (int i = 0; i < dim; ++i) { precondition[i] = static_cast<Real<T>>(1.0); }
72-
70+
// 3. set maxiter and funcs
7371
const int maxiter = hsolver::DiagoIterAssist<T>::PW_DIAG_NMAX;
7472

7573
auto hpsi_func = [&hm](T* psi_in, T* hpsi, const int ld_psi, const int nvec) {hm.hPsi(psi_in, hpsi, ld_psi, nvec);};
@@ -139,7 +137,8 @@ namespace LR
139137

140138
auto psi_tensor = ct::TensorMap(psi, ct::DataTypeToEnum<T>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ nband, dim }));
141139
auto eigen_tensor = ct::TensorMap(eigenvalue.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ nband }));
142-
auto precon_tensor = ct::TensorMap(precondition.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ dim }));
140+
std::vector<Real<T>> precondition_(precondition); //since TensorMap does not support const pointer
141+
auto precon_tensor = ct::TensorMap(precondition_.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ dim }));
143142
auto hpsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& hpsi) {hm.hPsi(psi_in.data<T>(), hpsi.data<T>(), psi_in.shape().dim_size(0) /*nbasis_local*/, 1/*band-by-band*/);};
144143
auto spsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& spsi)
145144
{ std::memcpy(spsi.data<T>(), psi_in.data<T>(), sizeof(T) * psi_in.NumElements()); };

source/module_lr/operator_casida/operator_lr_exx.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ namespace LR
2727
void OperatorLREXX<double>::cal_DM_onebase(const int io, const int iv, const int ik) const
2828
{
2929
ModuleBase::TITLE("OperatorLREXX", "cal_DM_onebase");
30+
// NOTICE: DM_onebase will be passed into `cal_energy` interface and conjugated by "zdotc".
31+
// So the formula should be the same as the RHS, instead of the LHS, of the A-matrix,
32+
// i.e. c1v · conj(c2o) · e^{-ik(R2-R1)}
3033
assert(ik == 0);
3134
for (auto cell : this->BvK_cells)
3235
{
@@ -42,8 +45,12 @@ namespace LR
4245
const int nw2 = aims_nbasis.empty() ? ucell.atoms[it2].nw : aims_nbasis[it2];
4346
for (int iw1 = 0;iw1 < nw1;++iw1)
4447
for (int iw2 = 0;iw2 < nw2;++iw2)
45-
if (this->pmat.in_this_processor(ucell.itiaiw2iwt(it1, ia1, iw1), ucell.itiaiw2iwt(it2, ia2, iw2)))
46-
D2d(iw1, iw2) = this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1)) * this->psi_ks_full(ik, nocc + iv, ucell.itiaiw2iwt(it2, ia2, iw2));
48+
{
49+
const int iwt1 = ucell.itiaiw2iwt(it1, ia1, iw1);
50+
const int iwt2 = ucell.itiaiw2iwt(it2, ia2, iw2);
51+
if (this->pmat.in_this_processor(iwt1, iwt2))
52+
D2d(iw1, iw2) = this->psi_ks_full(ik, io, iwt1) * this->psi_ks_full(ik, nocc + iv, iwt2);
53+
}
4754
}
4855
}
4956
}
@@ -52,6 +59,9 @@ namespace LR
5259
void OperatorLREXX<std::complex<double>>::cal_DM_onebase(const int io, const int iv, const int ik) const
5360
{
5461
ModuleBase::TITLE("OperatorLREXX", "cal_DM_onebase");
62+
// NOTICE: DM_onebase will be passed into `cal_energy` interface and conjugated by "zdotc".
63+
// So the formula should be the same as RHS. instead of LHS of the A-matrix,
64+
// i.e. c1v · conj(c2o) · e^{-ik(R2-R1)}
5565
for (auto cell : this->BvK_cells)
5666
{
5767
std::complex<double> frac = RI::Global_Func::convert<std::complex<double>>(std::exp(
@@ -68,8 +78,12 @@ namespace LR
6878
const int nw2 = aims_nbasis.empty() ? ucell.atoms[it2].nw : aims_nbasis[it2];
6979
for (int iw1 = 0;iw1 < nw1;++iw1)
7080
for (int iw2 = 0;iw2 < nw2;++iw2)
71-
if (this->pmat.in_this_processor(ucell.itiaiw2iwt(it1, ia1, iw1), ucell.itiaiw2iwt(it2, ia2, iw2)))
72-
D2d(iw1, iw2) = frac * std::conj(this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1))) * this->psi_ks_full(ik, nocc + iv, ucell.itiaiw2iwt(it2, ia2, iw2));
81+
{
82+
const int iwt1 = ucell.itiaiw2iwt(it1, ia1, iw1);
83+
const int iwt2 = ucell.itiaiw2iwt(it2, ia2, iw2);
84+
if (this->pmat.in_this_processor(iwt1, iwt2))
85+
D2d(iw1, iw2) = frac * std::conj(this->psi_ks_full(ik, io, iwt2)) * this->psi_ks_full(ik, nocc + iv, iwt1);
86+
}
7387
}
7488
}
7589
}
@@ -78,9 +92,6 @@ namespace LR
7892
void OperatorLREXX<T>::act(const int nbands, const int nbasis, const int npol, const T* psi_in, T* hpsi, const int ngk_ik, const bool is_first_node)const
7993
{
8094
ModuleBase::TITLE("OperatorLREXX", "act");
81-
82-
const int& nk = this->kv.get_nks() / this->nspin;
83-
8495
// convert parallel info to LibRI interfaces
8596
std::vector<std::tuple<std::set<TA>, std::set<TA>>> judge = RI_2D_Comm::get_2D_judge(this->pmat);
8697

@@ -120,12 +131,12 @@ namespace LR
120131
{
121132
const int xstart_bk = ik * pX.get_local_size();
122133
this->cal_DM_onebase(io, iv, ik); //set Ds_onebase for all e-h pairs (not only on this processor)
123-
// LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10);
134+
// LR_Util::print_CV(Ds_onebase, "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10);
124135
const T& ene = 2 * alpha * //minus for exchange(but here plus is right, why?), 2 for Hartree to Ry
125136
lri->exx_lri.post_2D.cal_energy(this->Ds_onebase, lri->Hexxs[0]);
126137
if (this->pX.in_this_processor(iv, io))
127138
{
128-
hpsi[xstart_bk + ik * pX.get_local_size() + this->pX.global2local_col(io) * this->pX.get_row_size() + this->pX.global2local_row(iv)] += ene;
139+
hpsi[xstart_bk + this->pX.global2local_col(io) * this->pX.get_row_size() + this->pX.global2local_row(iv)] += ene;
129140
}
130141
}
131142
}

source/module_lr/operator_casida/operator_lr_exx.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace LR
3232
const Parallel_Orbitals& pmat_in,
3333
const double& alpha = 1.0,
3434
const std::vector<int>& aims_nbasis = {})
35-
: nspin(nspin), naos(naos), nocc(nocc), nvirt(nvirt),
35+
: nspin(nspin), naos(naos), nocc(nocc), nvirt(nvirt), nk(kv_in.get_nks() / nspin),
3636
psi_ks(psi_ks_in), DM_trans(DM_trans_in), exx_lri(exx_lri_in), kv(kv_in),
3737
pX(pX_in), pc(pc_in), pmat(pmat_in), ucell(ucell_in), alpha(alpha),
3838
aims_nbasis(aims_nbasis)
@@ -42,8 +42,11 @@ namespace LR
4242
this->is_first_node = false;
4343

4444
// reduce psi_ks for later use
45-
this->psi_ks_full.resize(this->kv.get_nks(), nocc + nvirt, this->naos);
46-
LR_Util::gather_2d_to_full(this->pc, this->psi_ks.get_pointer(), this->psi_ks_full.get_pointer(), false, this->naos, nocc + nvirt);
45+
this->psi_ks_full.resize(this->nk, nocc + nvirt, this->naos);
46+
for (int ik = 0;ik < nk;++ik)
47+
{
48+
LR_Util::gather_2d_to_full(this->pc, &this->psi_ks(ik, 0, 0), &this->psi_ks_full(ik, 0, 0), false, this->naos, nocc + nvirt);
49+
}
4750

4851
// get cells in BvK supercell
4952
const TC period = RI_Util::get_Born_vonKarmen_period(kv_in);
@@ -63,12 +66,13 @@ namespace LR
6366
const int ngk_ik = 0,
6467
const bool is_first_node = false) const override;
6568

66-
private:
69+
private:
6770
//global sizes
68-
const int& nspin;
69-
const int& naos;
70-
const int& nocc;
71-
const int& nvirt;
71+
const int nspin = 1;
72+
const int naos = 1;
73+
const int nocc = 1;
74+
const int nvirt = 1;
75+
const int nk = 1; ///< number of k-points
7276
const double alpha = 1.0; //(allow non-ref constant)
7377
const bool cal_dm_trans = false;
7478
const bool tdm_sym = false; ///< whether transition density matrix is symmetric
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
totexcitationenergyref 0.784471
1+
totexcitationenergyref 0.784274
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
totexcitationenergyref 2.641295
1+
totexcitationenergyref 2.641190
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
totexcitationenergyref 3.067979
1+
totexcitationenergyref 3.067058

0 commit comments

Comments
 (0)