Skip to content

Commit d0b366c

Browse files
committed
prepare for PR
2 parents 600abc0 + 3c13b1a commit d0b366c

File tree

7 files changed

+60
-69
lines changed

7 files changed

+60
-69
lines changed

source/Makefile.Objects

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ VPATH=./src_global:\
7272
./module_lr/operator_casida:\
7373
./module_lr/potentials:\
7474
./module_lr/utils:\
75+
./module_rdmft:\
7576
./\
7677

7778
OBJS_ABACUS_PW=${OBJS_MAIN}\
@@ -113,6 +114,7 @@ ${OBJS_DELTASPIN}\
113114
${OBJS_TENSOR}\
114115
${OBJS_HSOLVER_PEXSI}\
115116
${OBJS_LR}\
117+
${OBJS_RDMFT}\
116118

117119
OBJS_MAIN=main.o\
118120
driver.o\
@@ -731,3 +733,7 @@ OBJS_TENSOR=tensor.o\
731733
lr_spectrum.o\
732734
hamilt_casida.o\
733735
esolver_lrtd_lcao.o\
736+
737+
OBJS_RDMFT=rdmft.o\
738+
rdmft_tools.o\
739+
rdmft_test.o\

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -955,14 +955,11 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int& iter)
955955
// 2.5) determine whether rdmft needs to get the initial value, added by jghan, 2024-10-25
956956
bool one_step_exx = false;
957957
bool get_init_value_rdmft = false;
958-
// the case without hybrid functionals
959-
if( iter == 1 ) get_init_value_rdmft = true;
960-
958+
if( iter == 1 ) get_init_value_rdmft = true; // the case without hybrid functionals
961959
#ifdef __EXX
962960
if( GlobalC::exx_info.info_global.cal_exx )
963961
{
964962
if( this->conv_esolver ) one_step_exx = true;
965-
966963
// the case with hybrid functionals
967964
if( one_step_exx && iter==1 ) get_init_value_rdmft = true;
968965
else get_init_value_rdmft = false;
@@ -1126,8 +1123,6 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int& iter)
11261123
ModuleBase::TITLE("RDMFT", "E & Egradient");
11271124
ModuleBase::timer::tick("RDMFT", "E & Egradient");
11281125

1129-
// if ( (!GlobalC::exx_info.info_global.cal_exx && iter == 1) || one_step_exx )
1130-
// if ( !GlobalC::exx_info.info_global.cal_exx || (GlobalC::exx_info.info_global.cal_exx && one_step_exx) )
11311126
if( get_init_value_rdmft )
11321127
{
11331128
ModuleBase::matrix occ_number_ks(this->pelec->wg);
@@ -1136,14 +1131,14 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int& iter)
11361131
for(int inb=0; inb < occ_number_ks.nc; ++inb) occ_number_ks(ik, inb) /= this->kv.wk[ik];
11371132
}
11381133

1139-
this->update_elec_rdmft(occ_number_ks, *(this->psi));
1134+
this->rdmft_solver.update_elec(occ_number_ks, *(this->psi));
11401135

11411136
//initialize the gradients of Etotal on occupation numbers and wfc, and set all elements to 0.
11421137
ModuleBase::matrix dE_dOccNum(this->pelec->wg.nr, this->pelec->wg.nc, true);
11431138
psi::Psi<TK> dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis());
11441139
dE_dWfc.zero_out();
11451140

1146-
double Etotal_RDMFT = this->run_rdmft(dE_dOccNum, dE_dWfc);
1141+
double Etotal_RDMFT = this->rdmft_solver.run(dE_dOccNum, dE_dWfc);
11471142

11481143
ModuleBase::timer::tick("RDMFT", "E & Egradient");
11491144

@@ -1256,14 +1251,14 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
12561251
for(int inb=0; inb < occ_number_ks.nc; ++inb) { occ_number_ks(ik, inb) /= this->kv.wk[ik];
12571252
}
12581253
}
1259-
this->update_elec_rdmft(occ_number_ks, *(this->psi));
1254+
this->rdmft_solver.update_elec(occ_number_ks, *(this->psi));
12601255

12611256
//initialize the gradients of Etotal on occupation numbers and wfc, and set all elements to 0.
12621257
ModuleBase::matrix dE_dOccNum(this->pelec->wg.nr, this->pelec->wg.nc, true);
12631258
psi::Psi<TK> dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis());
12641259
dE_dWfc.zero_out();
12651260

1266-
double Etotal_RDMFT = this->run_rdmft(dE_dOccNum, dE_dWfc);
1261+
double Etotal_RDMFT = this->rdmft_solver.run(dE_dOccNum, dE_dWfc);
12671262
}
12681263

12691264
/******** test RDMFT *********/
@@ -1395,19 +1390,6 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
13951390
}
13961391

13971392

1398-
template <typename TK, typename TR>
1399-
double ESolver_KS_LCAO<TK, TR>::run_rdmft(ModuleBase::matrix& E_gradient_occNum, psi::Psi<TK>& E_gradient_wfc)
1400-
{
1401-
return this->rdmft_solver.run(E_gradient_occNum, E_gradient_wfc);
1402-
}
1403-
1404-
template <typename TK, typename TR>
1405-
void ESolver_KS_LCAO<TK, TR>::update_elec_rdmft(const ModuleBase::matrix& occ_number_in, const psi::Psi<TK>& wfc_in)
1406-
{
1407-
this->rdmft_solver.update_elec(occ_number_in, wfc_in);
1408-
}
1409-
1410-
14111393
//------------------------------------------------------------------------------
14121394
//! the 20th,21th,22th functions of ESolver_KS_LCAO
14131395
//! mohan add 2024-05-11

source/module_esolver/esolver_ks_lcao.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,6 @@ class ESolver_KS_LCAO : public ESolver_KS<TK> {
4848

4949
void cal_mag(const int istep, const bool print = false);
5050

51-
double run_rdmft(ModuleBase::matrix& E_gradient_occNum, psi::Psi<TK>& E_gradient_wfc); // added by jghan for rdmft calculation, 2024-03-16
52-
53-
void update_elec_rdmft(const ModuleBase::matrix& occ_number_in, const psi::Psi<TK>& wfc_in); // added by jghan for rdmft calculation, 2024-03-16
54-
5551
protected:
5652
virtual void before_scf(const int istep) override;
5753

@@ -85,7 +81,7 @@ class ESolver_KS_LCAO : public ESolver_KS<TK> {
8581

8682
TwoCenterBundle two_center_bundle_;
8783

88-
rdmft::RDMFT<TK, TR> rdmft_solver; // added by jghan for rdmft calculation
84+
rdmft::RDMFT<TK, TR> rdmft_solver; // added by jghan for rdmft calculation, 2024-03-16
8985

9086
// temporary introduced during removing GlobalC::ORB
9187
LCAO_Orbitals orb_;

source/module_rdmft/rdmft.cpp

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,9 @@ void RDMFT<TK, TR>::init(Gint_Gamma& GG_in, Gint_k& GK_in, Parallel_Orbitals& Pa
117117
// para_Eij.blacs_ctxt = ParaV->blacs_ctxt;
118118
// para_Eij.set_local2global( GlobalV::NBANDS, GlobalV::NBANDS, ofs_running, ofs_warning );
119119
// para_Eij.set_desc( GlobalV::NBANDS, GlobalV::NBANDS, para_Eij.get_row_size(), false );
120-
120+
#ifdef __MPI
121121
para_Eij.set(nbands_total, nbands_total, ParaV->nb, ParaV->blacs_ctxt); // maybe in default, PARAM.inp.nb2d = 0, can't be used
122+
#endif
122123
// para_Eij.init(nbands_total, nbands_total, PARAM.inp.nb2d, MPI_COMM_WORLD);
123124
// // learn from "module_hamilt_lcao/hamilt_lcaodft/LCAO_init_basis.cpp"
124125

@@ -892,29 +893,33 @@ void RDMFT<TK, TR>::cal_Energy(const int cal_type)
892893
// }
893894
}
894895

895-
// print results
896-
std::cout << "\n\nfrom class RDMFT: \nXC_fun: " << XC_func_rdmft << std::endl;
896+
// // print results
897+
// std::cout << "\n\nfrom class RDMFT: \nXC_fun: " << XC_func_rdmft << std::endl;
898+
// #ifdef __EXX
899+
// if( GlobalC::exx_info.info_global.cal_exx ) std::cout << "alpha_power: " << alpha_power << std::endl;
900+
// #endif
901+
// std::cout << std::fixed << std::setprecision(10)
902+
// << "******\nE(TV + Hartree + XC) by RDMFT: " << E_RDMFT[3]
903+
// << "\n\nE_TV_RDMFT: " << E_RDMFT[0]
904+
// << "\nE_hartree_RDMFT: " << E_RDMFT[1]
905+
// << "\nExc_" << XC_func_rdmft << "_RDMFT: " << E_RDMFT[2]
906+
// << "\nE_Ewald: " << E_Ewald
907+
// << "\nE_entropy(-TS): " << E_entropy
908+
// << "\nE_descf: " << E_descf
909+
// << "\n\nEtotal_RDMFT: " << Etotal
910+
// << "\n\nExc_ksdft: " << E_xc_KS
911+
// << "\nE_exx_ksdft: " << E_exx_KS
912+
// <<"\n******\n\n" << std::endl;
913+
914+
// std::cout << "\netxc: " << etxc << "\nvtxc: " << vtxc << "\n";
915+
// std::cout << "\nE_deband_KS: " << E_deband_KS << "\nE_deband_harris_KS: " << E_deband_harris_KS << "\n\n" << std::endl;
916+
917+
if( PARAM.inp.ab_initio_type == "rdmft" )
918+
{
919+
GlobalV::ofs_running << "\n\nfrom class RDMFT: \nXC_fun: " << XC_func_rdmft << std::endl;
897920
#ifdef __EXX
898-
if( GlobalC::exx_info.info_global.cal_exx ) std::cout << "alpha_power: " << alpha_power << std::endl;
921+
if( GlobalC::exx_info.info_global.cal_exx ) GlobalV::ofs_running << "alpha_power: " << alpha_power << std::endl;
899922
#endif
900-
std::cout << std::fixed << std::setprecision(10)
901-
<< "******\nE(TV + Hartree + XC) by RDMFT: " << E_RDMFT[3]
902-
<< "\n\nE_TV_RDMFT: " << E_RDMFT[0]
903-
<< "\nE_hartree_RDMFT: " << E_RDMFT[1]
904-
<< "\nExc_" << XC_func_rdmft << "_RDMFT: " << E_RDMFT[2]
905-
<< "\nE_Ewald: " << E_Ewald
906-
<< "\nE_entropy(-TS): " << E_entropy
907-
<< "\nE_descf: " << E_descf
908-
<< "\n\nEtotal_RDMFT: " << Etotal
909-
<< "\n\nExc_ksdft: " << E_xc_KS
910-
<< "\nE_exx_ksdft: " << E_exx_KS
911-
<<"\n******\n\n" << std::endl;
912-
913-
std::cout << "\netxc: " << etxc << "\nvtxc: " << vtxc << "\n";
914-
std::cout << "\nE_deband_KS: " << E_deband_KS << "\nE_deband_harris_KS: " << E_deband_harris_KS << "\n\n" << std::endl;
915-
916-
if( 1 )
917-
{
918923
// GlobalV::ofs_running << std::setprecision(12);
919924
// GlobalV::ofs_running << std::setiosflags(std::ios::right);
920925
GlobalV::ofs_running << std::fixed << std::setprecision(10)
@@ -930,7 +935,6 @@ void RDMFT<TK, TR>::cal_Energy(const int cal_type)
930935
<< "\nE_exx_ksdft: " << E_exx_KS
931936
<<"\n******\n" << std::endl;
932937
}
933-
934938
std::cout << std::defaultfloat;
935939

936940
}

source/module_rdmft/rdmft.h

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,6 @@ class RDMFT
204204
// // update occ_number for optimization algorithms that depend on Hamilton
205205
// void update_wg(const ModuleBase::matrix& wg_in);
206206

207-
// get the total Hamilton in k-space
208-
void cal_Hk_Hpsi();
209-
210207
// do all calculation after update occNum&wfc, get Etotal and the gradient of energy with respect to the occNum&wfc
211208
double run(ModuleBase::matrix& E_gradient_occNum, psi::Psi<TK>&E_gradient_wfc);
212209

@@ -231,16 +228,14 @@ class RDMFT
231228

232229

233230
private:
231+
232+
// get the total Hamilton in k-space
233+
void cal_Hk_Hpsi();
234234

235235
void update_charge();
236236

237237

238238

239-
240-
241-
242-
243-
244239
};
245240

246241

source/module_rdmft/rdmft_tools.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,14 @@ void HkPsi<double>(const Parallel_Orbitals* ParaV, const double& HK, const doubl
6060
const char N_char = 'N';
6161
const char C_char = 'C';
6262

63+
#ifdef __MPI
6364
const int nbasis = ParaV->desc[2];
6465
const int nbands = ParaV->desc_wfc[3];
6566

6667
//because wfc(bands, basis'), H(basis, basis'), we do wfc*H^T(in the perspective of cpp, not in fortran). And get H_wfc(bands, basis) is correct.
6768
pdgemm_( &C_char, &N_char, &nbasis, &nbands, &nbasis, &one_double, &HK, &one_int, &one_int, ParaV->desc,
6869
&wfc, &one_int, &one_int, ParaV->desc_wfc, &zero_double, &H_wfc, &one_int, &one_int, ParaV->desc_wfc );
70+
#endif
6971

7072
}
7173

@@ -80,14 +82,16 @@ void psiDotPsi<double>(const Parallel_Orbitals* ParaV, const Parallel_2D& para_E
8082
const char N_char = 'N';
8183
const char T_char = 'T';
8284

83-
const int nbasis = ParaV->desc[2];
84-
const int nbands = ParaV->desc_wfc[3];
85-
8685
const int nrow_bands = para_Eij_in.get_row_size();
8786
const int ncol_bands = para_Eij_in.get_col_size();
8887

88+
#ifdef __MPI
89+
const int nbasis = ParaV->desc[2];
90+
const int nbands = ParaV->desc_wfc[3];
91+
8992
pdgemm_( &T_char, &N_char, &nbands, &nbands, &nbasis, &one_double, &wfc, &one_int, &one_int, ParaV->desc_wfc,
9093
&H_wfc, &one_int, &one_int, ParaV->desc_wfc, &zero_double, &Dmn[0], &one_int, &one_int, para_Eij_in.desc );
94+
#endif
9195

9296
for(int i=0; i<nrow_bands; ++i)
9397
{

source/module_rdmft/rdmft_tools.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,14 @@ void HkPsi(const Parallel_Orbitals* ParaV, const TK& HK, const TK& wfc, TK& H_wf
136136
const char N_char = 'N';
137137
const char C_char = 'C'; // Using 'C' is consistent with the formula
138138

139+
#ifdef __MPI
139140
const int nbasis = ParaV->desc[2];
140141
const int nbands = ParaV->desc_wfc[3];
141142

142143
//because wfc(bands, basis'), H(basis, basis'), we do wfc*H^T(in the perspective of cpp, not in fortran). And get H_wfc(bands, basis) is correct.
143144
pzgemm_( &C_char, &N_char, &nbasis, &nbands, &nbasis, &one_complex, &HK, &one_int, &one_int, ParaV->desc,
144145
&wfc, &one_int, &one_int, ParaV->desc_wfc, &zero_complex, &H_wfc, &one_int, &one_int, ParaV->desc_wfc );
146+
#endif
145147
}
146148

147149

@@ -158,14 +160,16 @@ void psiDotPsi(const Parallel_Orbitals* ParaV, const Parallel_2D& para_Eij_in, c
158160
const char N_char = 'N';
159161
const char C_char = 'C';
160162

161-
const int nbasis = ParaV->desc[2];
162-
const int nbands = ParaV->desc_wfc[3];
163-
164163
const int nrow_bands = para_Eij_in.get_row_size();
165164
const int ncol_bands = para_Eij_in.get_col_size();
166165

166+
#ifdef __MPI
167+
const int nbasis = ParaV->desc[2];
168+
const int nbands = ParaV->desc_wfc[3];
169+
167170
pzgemm_( &C_char, &N_char, &nbands, &nbands, &nbasis, &one_complex, &wfc, &one_int, &one_int, ParaV->desc_wfc,
168171
&H_wfc, &one_int, &one_int, ParaV->desc_wfc, &zero_complex, &Dmn[0], &one_int, &one_int, para_Eij_in.desc );
172+
#endif
169173

170174
for(int i=0; i<nrow_bands; ++i)
171175
{
@@ -196,8 +200,8 @@ void occNum_MulPsi(const Parallel_Orbitals* ParaV, const ModuleBase::matrix& occ
196200
const int nbands_local = wfc.get_nbands();
197201
const int nbasis_local = wfc.get_nbasis();
198202

199-
const int nbasis = ParaV->desc[2]; // need to be deleted
200-
const int nbands = ParaV->desc_wfc[3];
203+
// const int nbasis = ParaV->desc[2]; // need to be deleted
204+
// const int nbands = ParaV->desc_wfc[3];
201205

202206
for (int ik = 0; ik < nk_local; ++ik)
203207
{
@@ -224,8 +228,8 @@ void add_psi(const Parallel_Orbitals* ParaV, const K_Vectors* kv, const ModuleBa
224228
occNum_MulPsi(ParaV, occ_number, psi_dft_XC);
225229
occNum_MulPsi(ParaV, occ_number, psi_exx_XC, 2, XC_func_rdmft, alpha);
226230

227-
const int nbasis = ParaV->desc[2];
228-
const int nbands = ParaV->desc_wfc[3];
231+
// const int nbasis = ParaV->desc[2];
232+
// const int nbands = ParaV->desc_wfc[3];
229233

230234
for(int ik=0; ik<nk; ++ik)
231235
{

0 commit comments

Comments
 (0)