Skip to content

Commit cda658e

Browse files
committed
remove band-traverse and recover RI-benchmark
1 parent 80e64b2 commit cda658e

File tree

7 files changed

+76
-91
lines changed

7 files changed

+76
-91
lines changed

source/module_lr/hamilt_casida.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ namespace LR
4343
#endif
4444
}
4545
// output Amat
46-
std::cout << "Full A matrix:" << std::endl;
46+
std::cout << "Full A matrix: (elements < 1e-10 is set to 0)" << std::endl;
4747
LR_Util::print_value(Amat_full.data(), nk * npairs, nk * npairs);
4848
return Amat_full;
4949
}

source/module_lr/hamilt_casida.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ namespace LR
4646
if (ri_hartree_benchmark != "aims") { assert(aims_nbasis.empty()); }
4747
// always use nspin=1 for transition density matrix
4848
this->DM_trans = LR_Util::make_unique<elecstate::DensityMatrix<T, T>>(&pmat_in, 1, kv_in.kvec_d, nk);
49-
LR_Util::initialize_DMR(*this->DM_trans, pmat_in, ucell_in, gd_in, orb_cutoff);
49+
if (ri_hartree_benchmark == "none") { LR_Util::initialize_DMR(*this->DM_trans, pmat_in, ucell_in, gd_in, orb_cutoff); }
5050
// this->DM_trans->init_DMR(&gd_in, &ucell_in); // too large due to not restricted by orb_cutoff
5151

5252
// add the diag operator (the first one)

source/module_lr/operator_casida/operator_lr_diag.h

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,11 @@ namespace LR
4646
{
4747
ModuleBase::TITLE("OperatorLRDiag", "act");
4848
const int nlocal_ph = nk * pX.get_local_size(); // local size of particle-hole basis
49-
for (int ib = 0;ib < nbands;++ib)
50-
{
51-
const int ibstart = ib * nlocal_ph;
52-
hsolver::vector_mul_vector_op<T, Device>()(this->ctx,
53-
nk * pX.get_local_size(),
54-
hpsi + ibstart,
55-
psi_in + ibstart,
56-
this->eig_ks_diff.c);
57-
}
49+
hsolver::vector_mul_vector_op<T, Device>()(this->ctx,
50+
nk * pX.get_local_size(),
51+
hpsi,
52+
psi_in,
53+
this->eig_ks_diff.c);
5854
}
5955
private:
6056
const Parallel_2D& pX;

source/module_lr/operator_casida/operator_lr_exx.cpp

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -83,56 +83,54 @@ namespace LR
8383

8484
// convert parallel info to LibRI interfaces
8585
std::vector<std::tuple<std::set<TA>, std::set<TA>>> judge = RI_2D_Comm::get_2D_judge(this->pmat);
86-
for (int ib = 0;ib < nbands;++ib)
87-
{
88-
const int xstart_b = ib * nk * pX.get_local_size();
89-
// suppose Cs,Vs, have already been calculated in the ion-step of ground state
90-
// and DM_trans has been calculated in hPsi() outside.
9186

92-
// 1. set_Ds (once)
93-
// convert to vector<T*> for the interface of RI_2D_Comm::split_m2D_ktoR (interface will be unified to ct::Tensor)
94-
std::vector<std::vector<T>> DMk_trans_vector = this->DM_trans->get_DMK_vector();
95-
// assert(DMk_trans_vector.size() == nk);
96-
std::vector<const std::vector<T>*> DMk_trans_pointer(nk);
97-
for (int ik = 0;ik < nk;++ik) {DMk_trans_pointer[ik] = &DMk_trans_vector[ik];}
98-
// if multi-k, DM_trans(TR=double) -> Ds_trans(TR=T=complex<double>)
99-
std::vector<std::map<TA, std::map<TAC, RI::Tensor<T>>>> Ds_trans =
100-
aims_nbasis.empty() ?
101-
RI_2D_Comm::split_m2D_ktoR<T>(this->kv, DMk_trans_pointer, this->pmat, 1)
102-
: RI_Benchmark::split_Ds(DMk_trans_vector, aims_nbasis, ucell); //0.5 will be multiplied
103-
// LR_Util::print_CV(Ds_trans[0], "Ds_trans in OperatorLREXX", 1e-10);
104-
// 2. cal_Hs
105-
auto lri = this->exx_lri.lock();
87+
// suppose Cs,Vs, have already been calculated in the ion-step of ground state
88+
// and DM_trans has been calculated in hPsi() outside.
89+
90+
// 1. set_Ds (once)
91+
// convert to vector<T*> for the interface of RI_2D_Comm::split_m2D_ktoR (interface will be unified to ct::Tensor)
92+
std::vector<std::vector<T>> DMk_trans_vector = this->DM_trans->get_DMK_vector();
93+
// assert(DMk_trans_vector.size() == nk);
94+
std::vector<const std::vector<T>*> DMk_trans_pointer(nk);
95+
for (int ik = 0;ik < nk;++ik) { DMk_trans_pointer[ik] = &DMk_trans_vector[ik]; }
96+
// if multi-k, DM_trans(TR=double) -> Ds_trans(TR=T=complex<double>)
97+
std::vector<std::map<TA, std::map<TAC, RI::Tensor<T>>>> Ds_trans =
98+
aims_nbasis.empty() ?
99+
RI_2D_Comm::split_m2D_ktoR<T>(this->kv, DMk_trans_pointer, this->pmat, 1)
100+
: RI_Benchmark::split_Ds(DMk_trans_vector, aims_nbasis, ucell); //0.5 will be multiplied
101+
// LR_Util::print_CV(Ds_trans[0], "Ds_trans in OperatorLREXX", 1e-10);
102+
// 2. cal_Hs
103+
auto lri = this->exx_lri.lock();
106104

107-
// LR_Util::print_CV(Ds_trans[is], "Ds_trans in OperatorLREXX", 1e-10);
108-
lri->exx_lri.set_Ds(std::move(Ds_trans[0]), lri->info.dm_threshold);
109-
lri->exx_lri.cal_Hs();
110-
lri->Hexxs[0] = RI::Communicate_Tensors_Map_Judge::comm_map2_first(
111-
lri->mpi_comm, std::move(lri->exx_lri.Hs), std::get<0>(judge[0]), std::get<1>(judge[0]));
112-
lri->post_process_Hexx(lri->Hexxs[0]);
105+
// LR_Util::print_CV(Ds_trans[is], "Ds_trans in OperatorLREXX", 1e-10);
106+
lri->exx_lri.set_Ds(std::move(Ds_trans[0]), lri->info.dm_threshold);
107+
lri->exx_lri.cal_Hs();
108+
lri->Hexxs[0] = RI::Communicate_Tensors_Map_Judge::comm_map2_first(
109+
lri->mpi_comm, std::move(lri->exx_lri.Hs), std::get<0>(judge[0]), std::get<1>(judge[0]));
110+
lri->post_process_Hexx(lri->Hexxs[0]);
113111

114-
// 3. set [AX]_iak = DM_onbase * Hexxs for each occ-virt pair and each k-point
115-
// caution: parallel
112+
// 3. set [AX]_iak = DM_onbase * Hexxs for each occ-virt pair and each k-point
113+
// caution: parallel
116114

117-
for (int io = 0;io < this->nocc;++io)
115+
for (int io = 0;io < this->nocc;++io)
116+
{
117+
for (int iv = 0;iv < this->nvirt;++iv)
118118
{
119-
for (int iv = 0;iv < this->nvirt;++iv)
119+
for (int ik = 0;ik < nk;++ik)
120120
{
121-
for (int ik = 0;ik < nk;++ik)
121+
const int xstart_bk = ik * pX.get_local_size();
122+
this->cal_DM_onebase(io, iv, ik); //set Ds_onebase for all e-h pairs (not only on this processor)
123+
// LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10);
124+
const T& ene = 2 * alpha * //minus for exchange(but here plus is right, why?), 2 for Hartree to Ry
125+
lri->exx_lri.post_2D.cal_energy(this->Ds_onebase, lri->Hexxs[0]);
126+
if (this->pX.in_this_processor(iv, io))
122127
{
123-
const int xstart_bk = xstart_b + ik * pX.get_local_size();
124-
this->cal_DM_onebase(io, iv, ik); //set Ds_onebase for all e-h pairs (not only on this processor)
125-
// LR_Util::print_CV(Ds_onebase[is], "Ds_onebase of occ " + std::to_string(io) + ", virtual " + std::to_string(iv) + " in OperatorLREXX", 1e-10);
126-
const T& ene = 2 * alpha * //minus for exchange(but here plus is right, why?), 2 for Hartree to Ry
127-
lri->exx_lri.post_2D.cal_energy(this->Ds_onebase, lri->Hexxs[0]);
128-
if (this->pX.in_this_processor(iv, io))
129-
{
130-
hpsi[xstart_bk + ik * pX.get_local_size() + this->pX.global2local_col(io) * this->pX.get_row_size() + this->pX.global2local_row(iv)] += ene;
131-
}
128+
hpsi[xstart_bk + ik * pX.get_local_size() + this->pX.global2local_col(io) * this->pX.get_row_size() + this->pX.global2local_row(iv)] += ene;
132129
}
133130
}
134131
}
135132
}
133+
136134
}
137135
template class OperatorLREXX<double>;
138136
template class OperatorLREXX<std::complex<double>>;

source/module_lr/operator_casida/operator_lr_hxc.cpp

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,36 +22,31 @@ namespace LR
2222
ModuleBase::TITLE("OperatorLRHxc", "act");
2323
const int& sl = ispin_ks[0];
2424
const auto psil_ks = LR_Util::get_psi_spin(psi_ks, sl, nk);
25-
2625
const int& lgd = gint->gridt->lgd;
27-
for (int ib = 0;ib < nbands;++ib)
28-
{
29-
const int xstart_b = ib * nbasis;
30-
31-
this->DM_trans->cal_DMR(); //DM_trans->get_DMR_vector() is 2d-block parallelized
32-
// LR_Util::print_DMR(*DM_trans, ucell.nat, "DMR");
33-
34-
// ========================= begin grid calculation=========================
35-
this->grid_calculation(nbands); //DM(R) to H(R)
36-
// ========================= end grid calculation =========================
37-
38-
// V(R)->V(k)
39-
std::vector<ct::Tensor> v_hxc_2d(nk, LR_Util::newTensor<T>({ pmat.get_col_size(), pmat.get_row_size() }));
40-
for (auto& v : v_hxc_2d) v.zero();
41-
int nrow = ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver) ? this->pmat.get_row_size() : this->pmat.get_col_size();
42-
for (int ik = 0;ik < nk;++ik) { folding_HR(*this->hR, v_hxc_2d[ik].data<T>(), this->kv.kvec_d[ik], nrow, 1); } // V(R) -> V(k)
43-
// LR_Util::print_HR(*this->hR, this->ucell.nat, "4.VR");
44-
// if (this->first_print)
45-
// for (int ik = 0;ik < nk;++ik)
46-
// LR_Util::print_tensor<T>(v_hxc_2d[ik], "4.V(k)[ik=" + std::to_string(ik) + "]", &this->pmat);
47-
48-
// 5. [AX]^{Hxc}_{ai}=\sum_{\mu,\nu}c^*_{a,\mu,}V^{Hxc}_{\mu,\nu}c_{\nu,i}
26+
27+
this->DM_trans->cal_DMR(); //DM_trans->get_DMR_vector() is 2d-block parallelized
28+
// LR_Util::print_DMR(*DM_trans, ucell.nat, "DMR");
29+
30+
// ========================= begin grid calculation=========================
31+
this->grid_calculation(nbands); //DM(R) to H(R)
32+
// ========================= end grid calculation =========================
33+
34+
// V(R)->V(k)
35+
std::vector<ct::Tensor> v_hxc_2d(nk, LR_Util::newTensor<T>({ pmat.get_col_size(), pmat.get_row_size() }));
36+
for (auto& v : v_hxc_2d) v.zero();
37+
int nrow = ModuleBase::GlobalFunc::IS_COLUMN_MAJOR_KS_SOLVER(PARAM.inp.ks_solver) ? this->pmat.get_row_size() : this->pmat.get_col_size();
38+
for (int ik = 0;ik < nk;++ik) { folding_HR(*this->hR, v_hxc_2d[ik].data<T>(), this->kv.kvec_d[ik], nrow, 1); } // V(R) -> V(k)
39+
// LR_Util::print_HR(*this->hR, this->ucell.nat, "4.VR");
40+
// if (this->first_print)
41+
// for (int ik = 0;ik < nk;++ik)
42+
// LR_Util::print_tensor<T>(v_hxc_2d[ik], "4.V(k)[ik=" + std::to_string(ik) + "]", &this->pmat);
43+
44+
// 5. [AX]^{Hxc}_{ai}=\sum_{\mu,\nu}c^*_{a,\mu,}V^{Hxc}_{\mu,\nu}c_{\nu,i}
4945
#ifdef __MPI
50-
cal_AX_pblas(v_hxc_2d, this->pmat, psil_ks, this->pc, naos, nocc[sl], nvirt[sl], this->pX[sl], hpsi + xstart_b);
46+
cal_AX_pblas(v_hxc_2d, this->pmat, psil_ks, this->pc, naos, nocc[sl], nvirt[sl], this->pX[sl], hpsi);
5147
#else
52-
cal_AX_blas(v_hxc_2d, psil_ks, nocc[sl], nvirt[sl], hpsi + xstart_b);
48+
cal_AX_blas(v_hxc_2d, psil_ks, nocc[sl], nvirt[sl], hpsi);
5349
#endif
54-
}
5550
}
5651

5752

source/module_lr/ri_benchmark/operator_ri_hartree.h

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,22 +50,18 @@ namespace RI_Benchmark
5050
}
5151
};
5252
~OperatorRIHartree() {}
53-
void act(const psi::Psi<T>& X_in, psi::Psi<T>& X_out, const int nbands) const override
53+
void act(const int nbands, const int nbasis, const int npol, const T* psi_in, T* hpsi, const int ngk_ik = 0)const override
5454
{
5555
assert(GlobalV::MY_RANK == 0); // only serial now
56-
const int nk = 1;
57-
const psi::Psi<T>& X = LR_Util::k1_to_bfirst_wrapper(X_in, nk, npairs);
58-
psi::Psi<T> AX = LR_Util::k1_to_bfirst_wrapper(X_out, nk, npairs);
59-
for (int ib = 0;ib < nbands;++ib)
60-
{
61-
TLRIX<T> CsX_vo = cal_CsX(Cs_vo_mo, &X(ib, 0, 0));
62-
TLRIX<T> CsX_ov = cal_CsX(Cs_ov_mo, &X(ib, 0, 0));
63-
// LR_Util::print_CsX(Cs_bX, nvirt, "Cs_bX of state " + std::to_string(ib));
64-
cal_AX(CV_vo, CsX_vo, &AX(ib, 0, 0), 4.);
65-
cal_AX(CV_vo, CsX_ov, &AX(ib, 0, 0), 4.);
66-
cal_AX(CV_ov, CsX_vo, &AX(ib, 0, 0), 4.);
67-
cal_AX(CV_ov, CsX_ov, &AX(ib, 0, 0), 4.);
68-
}
56+
assert(nbasis == npairs);
57+
TLRIX<T> CsX_vo = cal_CsX(Cs_vo_mo, psi_in);
58+
TLRIX<T> CsX_ov = cal_CsX(Cs_ov_mo, psi_in);
59+
// LR_Util::print_CsX(Cs_bX, nvirt, "Cs_bX of state " + std::to_string(ib));
60+
// 4 for 4 terms in the expansion of local RI
61+
cal_AX(CV_vo, CsX_vo, hpsi, 4.);
62+
cal_AX(CV_vo, CsX_ov, hpsi, 4.);
63+
cal_AX(CV_ov, CsX_vo, hpsi, 4.);
64+
cal_AX(CV_ov, CsX_ov, hpsi, 4.);
6965
}
7066
protected:
7167
const int& naos;

source/module_lr/ri_benchmark/ri_benchmark.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ namespace RI_Benchmark
153153
return Amat_full;
154154
}
155155
template <typename TK>
156-
TLRIX<TK> cal_CsX(const TLRI<TK>& Cs_mo, TK* X)
156+
TLRIX<TK> cal_CsX(const TLRI<TK>& Cs_mo, const TK* X)
157157
{
158158
TLRIX<TK> CsX;
159159
for (auto& it1 : Cs_mo)

0 commit comments

Comments
 (0)