Skip to content

Commit 2ca6c6c

Browse files
committed
fix segfault and non-hermitian in op_lr_exx
1 parent 4c883be commit 2ca6c6c

File tree

4 files changed

+34
-19
lines changed

4 files changed

+34
-19
lines changed

source/module_lr/dm_trans/dm_trans_parallel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ std::vector<container::Tensor> cal_dm_trans_pblas(const std::complex<double>* X_
109109
// &beta, dm_trans[isk].data<std::complex<double>>(), &i1, &i1, pmat.desc);
110110

111111
// ============== [C_virt * X * C_occ^\dagger]^T=============
112-
// ============== = [C_occ^* * X^T * C_virt^T]^T=============
112+
// ============== = [C_occ^* * X^T * C_virt^T]=============
113113
// 1. X*C_occ^\dagger
114114
char transa = 'N';
115115
char transb = 'C';

source/module_lr/operator_casida/operator_lr_exx.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ namespace LR
2727
void OperatorLREXX<double>::cal_DM_onebase(const int io, const int iv, const int ik) const
2828
{
2929
ModuleBase::TITLE("OperatorLREXX", "cal_DM_onebase");
30+
// NOTICE: DM_onebase will be passed into `cal_energy` interface and conjugated by "zdotc".
31+
// So the formula should be the same as RHS. instead of LHS of the A-matrix,
32+
// i.e. c1v · conj(c2o) · e^{-ik(R2-R1)}
3033
assert(ik == 0);
3134
for (auto cell : this->BvK_cells)
3235
{
@@ -42,8 +45,12 @@ namespace LR
4245
const int nw2 = aims_nbasis.empty() ? ucell.atoms[it2].nw : aims_nbasis[it2];
4346
for (int iw1 = 0;iw1 < nw1;++iw1)
4447
for (int iw2 = 0;iw2 < nw2;++iw2)
45-
if (this->pmat.in_this_processor(ucell.itiaiw2iwt(it1, ia1, iw1), ucell.itiaiw2iwt(it2, ia2, iw2)))
46-
D2d(iw1, iw2) = this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1)) * this->psi_ks_full(ik, nocc + iv, ucell.itiaiw2iwt(it2, ia2, iw2));
48+
{
49+
const int iwt1 = ucell.itiaiw2iwt(it1, ia1, iw1);
50+
const int iwt2 = ucell.itiaiw2iwt(it2, ia2, iw2);
51+
if (this->pmat.in_this_processor(iwt1, iwt2))
52+
D2d(iw1, iw2) = this->psi_ks_full(ik, io, iwt1) * this->psi_ks_full(ik, nocc + iv, iwt2);
53+
}
4754
}
4855
}
4956
}
@@ -52,6 +59,9 @@ namespace LR
5259
void OperatorLREXX<std::complex<double>>::cal_DM_onebase(const int io, const int iv, const int ik) const
5360
{
5461
ModuleBase::TITLE("OperatorLREXX", "cal_DM_onebase");
62+
// NOTICE: DM_onebase will be passed into `cal_energy` interface and conjugated by "zdotc".
63+
// So the formula should be the same as RHS. instead of LHS of the A-matrix,
64+
// i.e. c1v · conj(c2o) · e^{-ik(R2-R1)}
5565
for (auto cell : this->BvK_cells)
5666
{
5767
std::complex<double> frac = RI::Global_Func::convert<std::complex<double>>(std::exp(
@@ -68,8 +78,12 @@ namespace LR
6878
const int nw2 = aims_nbasis.empty() ? ucell.atoms[it2].nw : aims_nbasis[it2];
6979
for (int iw1 = 0;iw1 < nw1;++iw1)
7080
for (int iw2 = 0;iw2 < nw2;++iw2)
71-
if (this->pmat.in_this_processor(ucell.itiaiw2iwt(it1, ia1, iw1), ucell.itiaiw2iwt(it2, ia2, iw2)))
72-
D2d(iw1, iw2) = frac * std::conj(this->psi_ks_full(ik, io, ucell.itiaiw2iwt(it1, ia1, iw1))) * this->psi_ks_full(ik, nocc + iv, ucell.itiaiw2iwt(it2, ia2, iw2));
81+
{
82+
const int iwt1 = ucell.itiaiw2iwt(it1, ia1, iw1);
83+
const int iwt2 = ucell.itiaiw2iwt(it2, ia2, iw2);
84+
if (this->pmat.in_this_processor(iwt1, iwt2))
85+
D2d(iw1, iw2) = frac * std::conj(this->psi_ks_full(ik, io, iwt2)) * this->psi_ks_full(ik, nocc + iv, iwt1);
86+
}
7387
}
7488
}
7589
}
@@ -78,9 +92,6 @@ namespace LR
7892
void OperatorLREXX<T>::act(const int nbands, const int nbasis, const int npol, const T* psi_in, T* hpsi, const int ngk_ik, const bool is_first_node)const
7993
{
8094
ModuleBase::TITLE("OperatorLREXX", "act");
81-
82-
const int& nk = this->kv.get_nks() / this->nspin;
83-
8495
// convert parallel info to LibRI interfaces
8596
std::vector<std::tuple<std::set<TA>, std::set<TA>>> judge = RI_2D_Comm::get_2D_judge(this->pmat);
8697

@@ -120,12 +131,12 @@ namespace LR
120131
{
121132
const int xstart_bk = ik * pX.get_local_size();
122133
this->cal_DM_onebase(io, iv, ik); //set Ds_onebase for all e-h pairs (not only on this processor)
123-
// LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10);
134+
// LR_Util::print_CV(Ds_onebase, "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10);
124135
const T& ene = 2 * alpha * //minus for exchange(but here plus is right, why?), 2 for Hartree to Ry
125136
lri->exx_lri.post_2D.cal_energy(this->Ds_onebase, lri->Hexxs[0]);
126137
if (this->pX.in_this_processor(iv, io))
127138
{
128-
hpsi[xstart_bk + ik * pX.get_local_size() + this->pX.global2local_col(io) * this->pX.get_row_size() + this->pX.global2local_row(iv)] += ene;
139+
hpsi[xstart_bk + this->pX.global2local_col(io) * this->pX.get_row_size() + this->pX.global2local_row(iv)] += ene;
129140
}
130141
}
131142
}

source/module_lr/operator_casida/operator_lr_exx.h

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace LR
3232
const Parallel_Orbitals& pmat_in,
3333
const double& alpha = 1.0,
3434
const std::vector<int>& aims_nbasis = {})
35-
: nspin(nspin), naos(naos), nocc(nocc), nvirt(nvirt),
35+
: nspin(nspin), naos(naos), nocc(nocc), nvirt(nvirt), nk(kv_in.get_nks() / nspin),
3636
psi_ks(psi_ks_in), DM_trans(DM_trans_in), exx_lri(exx_lri_in), kv(kv_in),
3737
pX(pX_in), pc(pc_in), pmat(pmat_in), ucell(ucell_in), alpha(alpha),
3838
aims_nbasis(aims_nbasis)
@@ -42,8 +42,11 @@ namespace LR
4242
this->is_first_node = false;
4343

4444
// reduce psi_ks for later use
45-
this->psi_ks_full.resize(this->kv.get_nks(), nocc + nvirt, this->naos);
46-
LR_Util::gather_2d_to_full(this->pc, this->psi_ks.get_pointer(), this->psi_ks_full.get_pointer(), false, this->naos, nocc + nvirt);
45+
this->psi_ks_full.resize(this->nk, nocc + nvirt, this->naos);
46+
for (int ik = 0;ik < nk;++ik)
47+
{
48+
LR_Util::gather_2d_to_full(this->pc, &this->psi_ks(ik, 0, 0), &this->psi_ks_full(ik, 0, 0), false, this->naos, nocc + nvirt);
49+
}
4750

4851
// get cells in BvK supercell
4952
const TC period = RI_Util::get_Born_vonKarmen_period(kv_in);
@@ -63,12 +66,13 @@ namespace LR
6366
const int ngk_ik = 0,
6467
const bool is_first_node = false) const override;
6568

66-
private:
69+
private:
6770
//global sizes
68-
const int& nspin;
69-
const int& naos;
70-
const int& nocc;
71-
const int& nvirt;
71+
const int nspin = 1;
72+
const int naos = 1;
73+
const int nocc = 1;
74+
const int nvirt = 1;
75+
const int nk = 1; ///< number of k-points
7276
const double alpha = 1.0; //(allow non-ref constant)
7377
const bool cal_dm_trans = false;
7478
const bool tdm_sym = false; ///< whether transition density matrix is symmetric
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
totexcitationenergyref 3.067082
1+
totexcitationenergyref 3.067058

0 commit comments

Comments
 (0)