Skip to content

Commit 3aab0de

Browse files
committed
Refactor psiDotPsi()
1 parent a61cc4b commit 3aab0de

File tree

4 files changed

+88
-6
lines changed

4 files changed

+88
-6
lines changed

abaInstall_HZWpara.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
rm -rf build
22

3-
cmake -B build -DCMAKE_INSTALL_PREFIX=/public1/home/t6s000394/jghan/software/abacus-develop/rdmft-abacus/ -DCMAKE_CXX_COMPILER=icpx -DELPA_DIR=/public1/home/t6s000394/jghan/software/elpa-2021.11/ -DCEREAL_INCLUDE_DIR=/public1/home/t6s000394/jghan/software/cereal/cereal-1.3.2/include
4-
#cmake -B build -DCMAKE_INSTALL_PREFIX=/public1/home/t6s000394/jghan/software/abacus-develop/rdmft-abacus/ -DCMAKE_CXX_COMPILER=icpx -DMPI_CXX_COMPILER=mpiicpc -DELPA_DIR=/public1/home/t6s000394/jghan/software/elpa-2021.11/ -DLibxc_DIR=/public1/home/t6s000394/jghan/software/libxc/ -DLIBRI_DIR=/public1/home/t6s000394/jghan/software/LibRI -DLIBCOMM_DIR=/public1/home/t6s000394/jghan/software/LibComm -DCEREAL_INCLUDE_DIR=/public1/home/t6s000394/jghan/software/cereal/cereal-1.3.2/include
3+
#cmake -B build -DCMAKE_INSTALL_PREFIX=/public1/home/t6s000394/jghan/software/abacus-develop/rdmft-abacus/ -DCMAKE_CXX_COMPILER=icpx -DELPA_DIR=/public1/home/t6s000394/jghan/software/elpa-2021.11/ -DCEREAL_INCLUDE_DIR=/public1/home/t6s000394/jghan/software/cereal/cereal-1.3.2/include
4+
cmake -B build -DCMAKE_INSTALL_PREFIX=/public1/home/t6s000394/jghan/software/abacus-develop/rdmft-abacus/ -DCMAKE_CXX_COMPILER=icpx -DMPI_CXX_COMPILER=mpiicpc -DELPA_DIR=/public1/home/t6s000394/jghan/software/elpa-2021.11/ -DLibxc_DIR=/public1/home/t6s000394/jghan/software/libxc/ -DLIBRI_DIR=/public1/home/t6s000394/jghan/software/LibRI -DLIBCOMM_DIR=/public1/home/t6s000394/jghan/software/LibComm -DCEREAL_INCLUDE_DIR=/public1/home/t6s000394/jghan/software/cereal/cereal-1.3.2/include
55

66
#cmake --build build -j 52 2>job.err
77
cmake --build build -j 92

source/module_rdmft/rdmft.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,12 @@ void RDMFT<TK, TR>::cal_Hk_Hpsi()
202202
HkPsi( ParaV, hsk_hartree->get_hk()[0], wfc(ik, 0, 0), H_wfc_hartree(ik, 0, 0));
203203

204204
// get wfc * H(k)_wfc
205-
psiDotPsi( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_TV(ik, 0, 0), Eij_TV, &(wfcHwfc_TV(ik, 0)) );
206-
psiDotPsi( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_hartree(ik, 0, 0), Eij_hartree, &(wfcHwfc_hartree(ik, 0)) );
205+
// psiDotPsi( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_TV(ik, 0, 0), Eij_TV, &(wfcHwfc_TV(ik, 0)) );
206+
// psiDotPsi( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_hartree(ik, 0, 0), Eij_hartree, &(wfcHwfc_hartree(ik, 0)) );
207+
cal_bra_op_ket( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_TV(ik, 0, 0), Eij_TV );
208+
cal_bra_op_ket( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_hartree(ik, 0, 0), Eij_hartree );
209+
_diagonal_in_serial( para_Eij, Eij_TV, &(wfcHwfc_TV(ik, 0)) );
210+
_diagonal_in_serial( para_Eij, Eij_hartree, &(wfcHwfc_hartree(ik, 0)) );
207211

208212
#ifdef __EXX
209213
if(GlobalC::exx_info.info_global.cal_exx)
@@ -212,7 +216,9 @@ void RDMFT<TK, TR>::cal_Hk_Hpsi()
212216

213217
V_exx_XC->contributeHk(ik);
214218
HkPsi( ParaV, hsk_exx_XC->get_hk()[0], wfc(ik, 0, 0), H_wfc_exx_XC(ik, 0, 0));
215-
psiDotPsi( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_exx_XC(ik, 0, 0), Eij_exx_XC, &(wfcHwfc_exx_XC(ik, 0)) );
219+
// psiDotPsi( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_exx_XC(ik, 0, 0), Eij_exx_XC, &(wfcHwfc_exx_XC(ik, 0)) );
220+
cal_bra_op_ket( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_exx_XC(ik, 0, 0), Eij_exx_XC );
221+
_diagonal_in_serial( para_Eij, Eij_exx_XC, &(wfcHwfc_exx_XC(ik, 0)) );
216222

217223
for(int iloc=0; iloc<HK_XC.size(); ++iloc) HK_XC[iloc] += hsk_exx_XC->get_hk()[iloc];
218224
}
@@ -223,7 +229,9 @@ void RDMFT<TK, TR>::cal_Hk_Hpsi()
223229

224230
V_dft_XC->contributeHk(ik);
225231
HkPsi( ParaV, hsk_dft_XC->get_hk()[0], wfc(ik, 0, 0), H_wfc_dft_XC(ik, 0, 0));
226-
psiDotPsi( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_dft_XC(ik, 0, 0), Eij_exx_XC, &(wfcHwfc_dft_XC(ik, 0)) );
232+
// psiDotPsi( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_dft_XC(ik, 0, 0), Eij_exx_XC, &(wfcHwfc_dft_XC(ik, 0)) );
233+
cal_bra_op_ket( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_dft_XC(ik, 0, 0), Eij_XC );
234+
_diagonal_in_serial( para_Eij, Eij_XC, &(wfcHwfc_dft_XC(ik, 0)) );
227235

228236
for(int iloc=0; iloc<HK_XC.size(); ++iloc) HK_XC[iloc] += hsk_dft_XC->get_hk()[iloc];
229237
}

source/module_rdmft/rdmft_tools.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,29 @@ void psiDotPsi<double>(const Parallel_Orbitals* ParaV, const Parallel_2D& para_E
8787
}
8888

8989

90+
//! Real-valued (double) specialization of cal_bra_op_ket.
//! Issues one distributed pdgemm_ call with op(A) = 'T' on wfc and op(B) = 'N' on H_wfc,
//! alpha = 1.0 and beta = 0.0, i.e. Dmn = wfc^T * H_wfc (no conjugation is needed for real data).
//!
//! @param ParaV        supplies desc_wfc (descriptor passed for both wfc and H_wfc) and the sizes
//!                     read below (desc[2] as nbasis, desc_wfc[3] as nbands)
//! @param para_Eij_in  supplies the descriptor (desc) passed for the result Dmn
//! @param wfc          first element of the local wavefunction block; its address is handed to pdgemm_
//! @param H_wfc        first element of the local operator-applied wavefunction block
//! @param Dmn          output buffer; &Dmn[0] is handed to pdgemm_ as the result matrix (beta = 0.0)
//!
//! NOTE(review): when __MPI is not defined nothing is computed and Dmn is left untouched.
//! NOTE(review): Dmn is not resized here — presumably the caller sizes it for the local block; confirm.
template <>
91+
void cal_bra_op_ket<double>(const Parallel_Orbitals* ParaV, const Parallel_2D& para_Eij_in,
92+
const double& wfc, const double& H_wfc, std::vector<double>& Dmn)
93+
{
94+
const int one_int = 1;
95+
const double one_double = 1.0;
96+
const double zero_double = 0.0;
97+
const char N_char = 'N';
98+
const char T_char = 'T';
99+
100+
// NOTE(review): nrow_bands and ncol_bands are computed but never used in this function.
const int nrow_bands = para_Eij_in.get_row_size();
101+
const int ncol_bands = para_Eij_in.get_col_size();
102+
103+
#ifdef __MPI
104+
// presumably the global row count of the basis-space matrix and the global column count
// of the wavefunction matrix (ScaLAPACK descriptor entries M_ and N_) — confirm
const int nbasis = ParaV->desc[2];
105+
const int nbands = ParaV->desc_wfc[3];
106+
107+
// result is nbands x nbands, contracted over nbasis; all sub-matrix offsets are (1, 1)
pdgemm_( &T_char, &N_char, &nbands, &nbands, &nbasis, &one_double, &wfc, &one_int, &one_int, ParaV->desc_wfc,
108+
&H_wfc, &one_int, &one_int, ParaV->desc_wfc, &zero_double, &Dmn[0], &one_int, &one_int, para_Eij_in.desc );
109+
#endif
110+
}
111+
112+
90113
// occNum_wfcHwfc = occNum*wfcHwfc + occNum_wfcHwfc
91114
// When symbol = 0, 1, 2, 3, 4, occNum = occNum, 0.5*occNum, g(occNum), 0.5*g(occNum), d_g(occNum)/d_occNum respectively. Default symbol=0.
92115
void occNum_Mul_wfcHwfc(const ModuleBase::matrix& occ_number,

source/module_rdmft/rdmft_tools.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,57 @@ void psiDotPsi<double>(const Parallel_Orbitals* ParaV, const Parallel_2D& para_w
169169
const double& wfc, const double& H_wfc, std::vector<double>& Dmn, double* wfcHwfc);
170170

171171

172+
//! Compute the band-space matrix Dmn(m, n) = sum_mu conj(wfc(ik, m, mu)) * op_wfc(ik, n, mu)
//! with one distributed pzgemm_ call: op(A) = 'C' (conjugate transpose) on wfc, op(B) = 'N' on H_wfc,
//! alpha = 1 and beta = 0.
//!
//! @param ParaV        supplies desc_wfc (descriptor passed for both wfc and H_wfc) and the sizes
//!                     read below (desc[2] as nbasis, desc_wfc[3] as nbands)
//! @param para_Eij_in  supplies the descriptor (desc) passed for the result Dmn
//! @param wfc          first element of the local wavefunction block; its address is handed to pzgemm_
//! @param H_wfc        first element of the local operator-applied wavefunction block
//! @param Dmn          output buffer; &Dmn[0] is handed to pzgemm_ as the result matrix (beta = 0)
//!
//! NOTE(review): the body always uses std::complex<double> constants and pzgemm_, whatever TK is;
//! the double case goes through the explicit specialization cal_bra_op_ket<double>.
//! NOTE(review): when __MPI is not defined nothing is computed and Dmn is left untouched.
//! NOTE(review): Dmn is not resized here — presumably the caller sizes it for the local block; confirm.
173+
template <typename TK>
174+
void cal_bra_op_ket(const Parallel_Orbitals* ParaV, const Parallel_2D& para_Eij_in, const TK& wfc, const TK& H_wfc, std::vector<TK>& Dmn)
175+
{
176+
const int one_int = 1;
177+
const std::complex<double> one_complex = {1.0, 0.0};
178+
const std::complex<double> zero_complex = {0.0, 0.0};
179+
const char N_char = 'N';
180+
const char C_char = 'C';
181+
182+
// NOTE(review): nrow_bands and ncol_bands are computed but never used in this function.
const int nrow_bands = para_Eij_in.get_row_size();
183+
const int ncol_bands = para_Eij_in.get_col_size();
184+
185+
#ifdef __MPI
186+
// presumably the global row count of the basis-space matrix and the global column count
// of the wavefunction matrix (ScaLAPACK descriptor entries M_ and N_) — confirm
const int nbasis = ParaV->desc[2];
187+
const int nbands = ParaV->desc_wfc[3];
188+
189+
// result is nbands x nbands, contracted over nbasis; all sub-matrix offsets are (1, 1)
pzgemm_( &C_char, &N_char, &nbands, &nbands, &nbasis, &one_complex, &wfc, &one_int, &one_int, ParaV->desc_wfc,
190+
&H_wfc, &one_int, &one_int, ParaV->desc_wfc, &zero_complex, &Dmn[0], &one_int, &one_int, para_Eij_in.desc );
191+
#endif
192+
}
193+
194+
195+
//! Declaration of the explicit specialization of cal_bra_op_ket for TK = double.
//! Declaring it in the header keeps the primary template from being instantiated for double;
//! the definition is presumably in rdmft_tools.cpp — confirm.
template <>
196+
void cal_bra_op_ket<double>(const Parallel_Orbitals* ParaV, const Parallel_2D& para_Eij_in,
197+
const double& wfc, const double& H_wfc, std::vector<double>& Dmn);
198+
199+
200+
//! For a matrix Dmn distributed according to para_Eij_in (2D-block rule), copy the real part of
//! every diagonal element held in the local block into wfcHwfc, indexed by the global band index.
//!
//! @param para_Eij_in  provides the local block sizes and the local-to-global row/column index maps
//! @param Dmn          local block of the distributed matrix, read as column-major with
//!                     leading dimension nrow_bands
//! @param wfcHwfc      output array indexed by global index; only entries whose diagonal element
//!                     is found in the local block are written, all others are left unchanged
//!
//! NOTE(review): no communication happens here, so each process fills only its own diagonal
//! entries — presumably the caller combines them across processes afterwards; confirm.
//! NOTE(review): wfcHwfc is assumed to be large enough for every global index written; confirm.
201+
template <typename TK>
202+
void _diagonal_in_serial(const Parallel_2D& para_Eij_in, const std::vector<TK>& Dmn, double* wfcHwfc)
203+
{
204+
const int nrow_bands = para_Eij_in.get_row_size();
205+
const int ncol_bands = para_Eij_in.get_col_size();
206+
207+
// scan every local (row, column) pair and keep those that map to the same global index
for(int i=0; i<nrow_bands; ++i)
208+
{
209+
int i_global = para_Eij_in.local2global_row(i);
210+
for(int j=0; j<ncol_bands; ++j)
211+
{
212+
int j_global = para_Eij_in.local2global_col(j);
213+
if(i_global==j_global)
214+
{
215+
// the local block of Dmn produced by pzgemm_() is stored column-major, hence i + j*nrow_bands
216+
wfcHwfc[j_global] = std::real( Dmn[i+j*nrow_bands] );
217+
}
218+
}
219+
}
220+
}
221+
222+
172223
//! Compute occNum_wfc = occNum * wfc. After calling this function, wfc holds occNum*wfc.
173224
template <typename TK>
174225
void occNum_MulPsi(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& occ_number, psi::Psi<TK>& wfc, int symbol = 0,

0 commit comments

Comments
 (0)