Skip to content

Commit 641a9a0

Browse files
committed
move rdmft related functions and their calls from esolver_ks to esolver_ks_lcao, and so on
1 parent af28473 commit 641a9a0

File tree

8 files changed

+73
-69
lines changed

8 files changed

+73
-69
lines changed

abaInstall_HZWpara.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
rm -rf build
22

3+
#cmake -B build -DCMAKE_INSTALL_PREFIX=/public1/home/t6s000394/jghan/software/abacus-develop/rdmft-abacus/ -DCMAKE_CXX_COMPILER=icpx -DMPI_CXX_COMPILER=mpiicpc -DELPA_DIR=/public1/home/t6s000394/jghan/software/elpa-2021.11/ -DLIBCOMM_DIR=/public1/home/t6s000394/jghan/software/LibComm -DCEREAL_INCLUDE_DIR=/public1/home/t6s000394/jghan/software/cereal/cereal-1.3.2/include
34
cmake -B build -DCMAKE_INSTALL_PREFIX=/public1/home/t6s000394/jghan/software/abacus-develop/rdmft-abacus/ -DCMAKE_CXX_COMPILER=icpx -DMPI_CXX_COMPILER=mpiicpc -DELPA_DIR=/public1/home/t6s000394/jghan/software/elpa-2021.11/ -DLibxc_DIR=/public1/home/t6s000394/jghan/software/libxc/ -DLIBRI_DIR=/public1/home/t6s000394/jghan/software/LibRI -DLIBCOMM_DIR=/public1/home/t6s000394/jghan/software/LibComm -DCEREAL_INCLUDE_DIR=/public1/home/t6s000394/jghan/software/cereal/cereal-1.3.2/include
45

56
#cmake --build build -j 52 2>job.err

source/module_esolver/esolver_ks.cpp

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,6 @@
2828

2929
#include "print_funcs.h" // mohan add 2024-07-27
3030

31-
// test by jghan
32-
#include "module_rdmft/rdmft_tools.h"
33-
3431
namespace ModuleESolver
3532
{
3633

@@ -596,10 +593,6 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
596593
// this->phamilt->update(conv_esolver);
597594
this->update_pot(istep, iter);
598595

599-
// 9.5) rdmft, add by jghan 2024-10-09
600-
bool one_step_exx = false;
601-
if( GlobalC::exx_info.info_global.cal_exx && this->conv_esolver ) one_step_exx = true;
602-
603596
// 10) finish scf iterations
604597
this->iter_finish(iter);
605598
#ifdef __MPI
@@ -630,36 +623,6 @@ void ESolver_KS<T, Device>::runner(const int istep, UnitCell& ucell)
630623
duration);
631624
#endif //__RAPIDJSON
632625

633-
// 12.5) rdmft, add by jghan 2024-04-08/2024-10-09
634-
if ( PARAM.inp.ab_initio_type == "rdmft" )
635-
{
636-
ModuleBase::TITLE("RDMFT", "E & Egradient");
637-
ModuleBase::timer::tick("RDMFT", "E & Egradient");
638-
639-
// if ( (!GlobalC::exx_info.info_global.cal_exx && iter == 1) || one_step_exx )
640-
if ( !GlobalC::exx_info.info_global.cal_exx || (GlobalC::exx_info.info_global.cal_exx && one_step_exx) )
641-
{
642-
ModuleBase::matrix occ_number_ks(this->pelec->wg);
643-
for(int ik=0; ik < occ_number_ks.nr; ++ik)
644-
{
645-
for(int inb=0; inb < occ_number_ks.nc; ++inb) occ_number_ks(ik, inb) /= this->kv.wk[ik];
646-
}
647-
648-
this->update_elec_rdmft(occ_number_ks, *(this->psi));
649-
650-
//initialize the gradients of Etotal on occupation numbers and wfc, and set all elements to 0.
651-
ModuleBase::matrix dE_dOccNum(this->pelec->wg.nr, this->pelec->wg.nc, true);
652-
psi::Psi<T> dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis());
653-
dE_dWfc.zero_out();
654-
655-
double Etotal_RDMFT = this->run_rdmft(dE_dOccNum, dE_dWfc);
656-
657-
ModuleBase::timer::tick("RDMFT", "E & Egradient");
658-
659-
// break;
660-
}
661-
}
662-
663626
// 13) check convergence
664627
if (this->conv_esolver)
665628
{

source/module_esolver/esolver_ks.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,6 @@ class ESolver_KS : public ESolver_FP
7171
//! <Temporary> It should be replaced by a function in Hamilt Class
7272
virtual void update_pot(const int istep, const int iter) {};
7373

74-
virtual double run_rdmft(ModuleBase::matrix& E_gradient_occNum, psi::Psi<T>& E_gradient_wfc) { return 0.0; }; // add by jghan, 2024-03-16
75-
76-
virtual void update_elec_rdmft(const ModuleBase::matrix& occ_number_in, const psi::Psi<T>& wfc_in) {}; // add by jghan, 2024-03-16
77-
7874
protected:
7975
// Print inforamtion in each iter
8076
// G1 -3.435545e+03 0.000000e+00 3.607e-01 2.862e-01

source/module_esolver/esolver_ks_lcao.cpp

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,6 @@
5757

5858
// test RDMFT
5959
#include "module_rdmft/rdmft.h"
60-
#include "module_rdmft/rdmft_tools.h"
61-
#include "module_rdmft/rdmft_test.h"
6260
#include <iostream>
6361

6462
namespace ModuleESolver
@@ -261,17 +259,14 @@ void ESolver_KS_LCAO<TK, TR>::before_all_runners(const Input_para& inp, UnitCell
261259
}
262260

263261

264-
// add by jghan for rdmft calculation
262+
// 15) initialize rdmft, added by jghan
265263
if( PARAM.inp.ab_initio_type == "rdmft" )
266264
{
267265
rdmft_solver.init( this->GG, this->GK, this->pv, ucell, this->kv, *(this->pelec),
268266
this->orb_, two_center_bundle_, PARAM.inp.dft_functional, PARAM.inp.rdmft_power_alpha);
269-
270-
// the initialization and necessary calculations of these quantities have been completed in init()
271-
// rdmft_solver.update_ion(ucell, LM, *(this->pw_rho), GlobalC::ppcell.vloc, this->sf.strucFac, this->LOC);
272267
}
273268

274-
// 15) if kpar is not divisible by nks, print a warning
269+
// 16) if kpar is not divisible by nks, print a warning
275270
if (GlobalV::KPAR_LCAO > 1)
276271
{
277272
if (this->kv.get_nks() % GlobalV::KPAR_LCAO != 0)
@@ -955,6 +950,10 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int& iter)
955950
}
956951
}
957952

953+
// 2.5) determine whether to update exx, added by jghan, 2024-10-25
954+
bool one_step_exx = false;
955+
if( GlobalC::exx_info.info_global.cal_exx && this->conv_esolver ) one_step_exx = true;
956+
958957
#ifdef __EXX
959958
// 3) save exx matrix
960959
int two_level_step = GlobalC::exx_info.info_ri.real_number ? this->exd->two_level_step : this->exc->two_level_step;
@@ -1102,6 +1101,38 @@ void ESolver_KS_LCAO<TK, TR>::iter_finish(int& iter)
11021101
{
11031102
GlobalC::dftu.initialed_locale = true;
11041103
}
1104+
1105+
// 7) rdmft, added by jghan, 2024-10-25
1106+
if ( PARAM.inp.ab_initio_type == "rdmft" )
1107+
{
1108+
ModuleBase::TITLE("RDMFT", "E & Egradient");
1109+
ModuleBase::timer::tick("RDMFT", "E & Egradient");
1110+
1111+
// if ( (!GlobalC::exx_info.info_global.cal_exx && iter == 1) || one_step_exx )
1112+
if ( !GlobalC::exx_info.info_global.cal_exx || (GlobalC::exx_info.info_global.cal_exx && one_step_exx) )
1113+
{
1114+
ModuleBase::matrix occ_number_ks(this->pelec->wg);
1115+
for(int ik=0; ik < occ_number_ks.nr; ++ik)
1116+
{
1117+
for(int inb=0; inb < occ_number_ks.nc; ++inb) occ_number_ks(ik, inb) /= this->kv.wk[ik];
1118+
}
1119+
1120+
this->update_elec_rdmft(occ_number_ks, *(this->psi));
1121+
1122+
//initialize the gradients of Etotal on occupation numbers and wfc, and set all elements to 0.
1123+
ModuleBase::matrix dE_dOccNum(this->pelec->wg.nr, this->pelec->wg.nc, true);
1124+
psi::Psi<TK> dE_dWfc(this->psi->get_nk(), this->psi->get_nbands(), this->psi->get_nbasis());
1125+
dE_dWfc.zero_out();
1126+
1127+
double Etotal_RDMFT = this->run_rdmft(dE_dOccNum, dE_dWfc);
1128+
1129+
ModuleBase::timer::tick("RDMFT", "E & Egradient");
1130+
1131+
// break;
1132+
one_step_exx = false;
1133+
}
1134+
}
1135+
11051136
}
11061137

11071138
//------------------------------------------------------------------------------
@@ -1348,13 +1379,13 @@ void ESolver_KS_LCAO<TK, TR>::after_scf(const int istep)
13481379
template <typename TK, typename TR>
13491380
double ESolver_KS_LCAO<TK, TR>::run_rdmft(ModuleBase::matrix& E_gradient_occNum, psi::Psi<TK>& E_gradient_wfc)
13501381
{
1351-
return rdmft_solver.run(E_gradient_occNum, E_gradient_wfc);
1382+
return this->rdmft_solver.run(E_gradient_occNum, E_gradient_wfc);
13521383
}
13531384

13541385
template <typename TK, typename TR>
13551386
void ESolver_KS_LCAO<TK, TR>::update_elec_rdmft(const ModuleBase::matrix& occ_number_in, const psi::Psi<TK>& wfc_in)
13561387
{
1357-
rdmft_solver.update_elec(occ_number_in, wfc_in);
1388+
this->rdmft_solver.update_elec(occ_number_in, wfc_in);
13581389
}
13591390

13601391

source/module_esolver/esolver_ks_lcao.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include "module_basis/module_nao/two_center_bundle.h"
1313
#include "module_io/output_mat_sparse.h"
1414

15-
// add by jghan for rdmft calculation
15+
// added by jghan for rdmft calculation
1616
#include "module_rdmft/rdmft.h"
1717

1818
#include <memory>
@@ -48,9 +48,9 @@ class ESolver_KS_LCAO : public ESolver_KS<TK> {
4848

4949
void cal_mag(const int istep, const bool print = false);
5050

51-
virtual double run_rdmft(ModuleBase::matrix& E_gradient_occNum, psi::Psi<TK>& E_gradient_wfc) override; // add by jghan for rdmft calculation, 2024-03-16
51+
double run_rdmft(ModuleBase::matrix& E_gradient_occNum, psi::Psi<TK>& E_gradient_wfc); // added by jghan for rdmft calculation, 2024-03-16
5252

53-
virtual void update_elec_rdmft(const ModuleBase::matrix& occ_number_in, const psi::Psi<TK>& wfc_in) override; // add by jghan for rdmft calculation, 2024-03-16
53+
void update_elec_rdmft(const ModuleBase::matrix& occ_number_in, const psi::Psi<TK>& wfc_in); // added by jghan for rdmft calculation, 2024-03-16
5454

5555
protected:
5656
virtual void before_scf(const int istep) override;
@@ -85,7 +85,7 @@ class ESolver_KS_LCAO : public ESolver_KS<TK> {
8585

8686
TwoCenterBundle two_center_bundle_;
8787

88-
rdmft::RDMFT<TK, TR> rdmft_solver; // add by jghan for rdmft calculation
88+
rdmft::RDMFT<TK, TR> rdmft_solver; // added by jghan for rdmft calculation
8989

9090
// temporary introduced during removing GlobalC::ORB
9191
LCAO_Orbitals orb_;

source/module_rdmft/rdmft.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,10 @@ RDMFT<TK, TR>::~RDMFT()
6565
// delete HR_local;
6666
delete HR_dft_XC;
6767

68+
#ifdef __EXX
6869
delete Vxc_fromRI_d;
6970
delete Vxc_fromRI_c;
71+
#endif
7072

7173
delete V_ekinetic_potential;
7274
delete V_nonlocal;
@@ -97,7 +99,7 @@ void RDMFT<TK, TR>::init(Gint_Gamma& GG_in, Gint_k& GK_in, Parallel_Orbitals& Pa
9799
// if (ModuleSymmetry::Symmetry::symm_flag == -1) nk_total = kv->nkstot_full;
98100
// else nk_total = kv->nks;
99101

100-
nk_total = ModuleSymmetry::Symmetry::symm_flag == -1 ? kv->nkstot_full: kv->nks;
102+
nk_total = ModuleSymmetry::Symmetry::symm_flag == -1 ? kv->get_nkstot_full(): kv->get_nks();
101103
nbands_total = PARAM.inp.nbands;
102104
// nbands_total = GlobalV::NBANDS;
103105
nspin = PARAM.inp.nspin;
@@ -187,6 +189,7 @@ void RDMFT<TK, TR>::init(Gint_Gamma& GG_in, Gint_k& GK_in, Parallel_Orbitals& Pa
187189
HR_dft_XC->set_zero();
188190
// HR_local->set_zero();
189191

192+
#ifdef __EXX
190193
if( GlobalC::exx_info.info_global.cal_exx )
191194
{
192195
if (GlobalC::exx_info.info_ri.real_number)
@@ -200,6 +203,7 @@ void RDMFT<TK, TR>::init(Gint_Gamma& GG_in, Gint_k& GK_in, Parallel_Orbitals& Pa
200203
Vxc_fromRI_c->init(MPI_COMM_WORLD, *kv, *orb);
201204
}
202205
}
206+
#endif
203207

204208
if( PARAM.inp.gamma_only )
205209
{
@@ -228,6 +232,7 @@ void RDMFT<TK, TR>::update_ion(UnitCell& ucell_in, ModulePW::PW_Basis& rho_basis
228232
HR_TV->set_zero();
229233
this->cal_V_TV();
230234

235+
#ifdef __EXX
231236
if( GlobalC::exx_info.info_global.cal_exx )
232237
{
233238
if (GlobalC::exx_info.info_ri.real_number)
@@ -239,6 +244,7 @@ void RDMFT<TK, TR>::update_ion(UnitCell& ucell_in, ModulePW::PW_Basis& rho_basis
239244
Vxc_fromRI_c->cal_exx_ions();
240245
}
241246
}
247+
#endif
242248

243249
std::cout << "\n\n\n******\ndo rdmft_esolver.update_ion() successfully\n******\n\n\n" << std::endl;
244250
}
@@ -648,7 +654,8 @@ void RDMFT<TK, TR>::cal_V_XC()
648654
}
649655
V_dft_XC->contributeHR();
650656
}
651-
657+
658+
#ifdef __EXX
652659
if(GlobalC::exx_info.info_global.cal_exx)
653660
{
654661
if (GlobalC::exx_info.info_ri.real_number)
@@ -704,6 +711,8 @@ void RDMFT<TK, TR>::cal_V_XC()
704711
// use hamilt::Add_Hexx_Type::k, not R, contributeHR() should be skipped
705712
// V_exx_XC->contributeHR();
706713
}
714+
#endif
715+
707716
}
708717

709718

@@ -736,6 +745,7 @@ void RDMFT<TK, TR>::cal_Hk_Hpsi()
736745
psiDotPsi( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_TV(ik, 0, 0), Eij_TV, &(wfcHwfc_TV(ik, 0)) );
737746
psiDotPsi( ParaV, para_Eij, wfc(ik, 0, 0), H_wfc_hartree(ik, 0, 0), Eij_hartree, &(wfcHwfc_hartree(ik, 0)) );
738747

748+
#ifdef __EXX
739749
if(GlobalC::exx_info.info_global.cal_exx)
740750
{
741751
// set_zero_vector(HK_exx_XC);
@@ -747,6 +757,7 @@ void RDMFT<TK, TR>::cal_Hk_Hpsi()
747757

748758
for(int iloc=0; iloc<HK_XC.size(); ++iloc) HK_XC[iloc] += hsk_exx_XC->get_hk()[iloc];
749759
}
760+
#endif
750761

751762
if( !only_exx_type )
752763
{
@@ -900,7 +911,7 @@ void RDMFT<TK, TR>::cal_Energy(const int cal_type)
900911
<< "\nExc_" << XC_func_rdmft << "_RDMFT: " << E_RDMFT[2]
901912
<< "\nE_Ewald: " << E_Ewald
902913
<< "\nE_entropy(-TS): " << E_entropy
903-
<< "\nE_descf: " << E_descf
914+
<< "\nE_descf: " << E_descf
904915
<< "\n\nEtotal_RDMFT: " << Etotal
905916
<< "\n\nExc_ksdft: " << E_xc_KS
906917
<< "\nE_exx_ksdft: " << E_exx_KS
@@ -927,7 +938,7 @@ void RDMFT<TK, TR>::cal_Energy(const int cal_type)
927938
<<"\n******\n" << std::endl;
928939
}
929940

930-
ModuleBase::timer::tick("rdmftTest", "RDMFT_E&Egradient");
941+
std::cout << std::defaultfloat;
931942

932943
}
933944

source/module_rdmft/rdmft.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,25 @@
2323
#include "module_base/parallel_reduce.h"
2424
// #include "module_elecstate/module_dm/cal_dm_psi.h"
2525
#include "module_elecstate/module_dm/density_matrix.h"
26+
#include "module_hamilt_lcao/hamilt_lcaodft/hs_matrix_k.hpp"
2627

2728
#include "module_basis/module_ao/ORB_read.h"
2829
#include "module_basis/module_nao/two_center_bundle.h"
2930

3031
#include "module_hamilt_general/operator.h"
3132
#include "module_hamilt_lcao/module_hcontainer/hcontainer.h"
3233
#include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/operator_lcao.h"
33-
#include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.h"
3434
#include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/ekinetic_new.h"
3535
#include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/nonlocal_new.h"
3636
#include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/veff_lcao.h"
3737

38-
#include "module_hamilt_lcao/hamilt_lcaodft/hs_matrix_k.hpp"
39-
40-
// #include "module_directmin/manifold/stiefel.h"
4138

4239
// used by Exx&LRI
43-
#include "module_ri/RI_2D_Comm.h"
40+
#ifdef __EXX
4441
#include "module_ri/Exx_LRI.h"
42+
#include "module_ri/RI_2D_Comm.h"
43+
#include "module_hamilt_lcao/hamilt_lcaodft/operator_lcao/op_exx_lcao.h"
44+
#endif
4545

4646
// there are some operator reload to print data in different formats
4747
#include "module_ri/test_code/test_function.h"
@@ -176,8 +176,10 @@ class RDMFT
176176
hamilt::OperatorLCAO<TK, TR>* V_hartree_XC = nullptr;
177177
// bool get_V_local_temp = true;
178178

179+
#ifdef __EXX
179180
Exx_LRI<double>* Vxc_fromRI_d = nullptr;
180181
Exx_LRI<std::complex<double>>* Vxc_fromRI_c = nullptr;
182+
#endif
181183

182184
double Etotal = 0.0;
183185
double etxc = 0.0;

source/module_ri/Exx_LRI.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ class Exx_LRI
5959
void init(const MPI_Comm &mpi_comm_in, const K_Vectors &kv_in, const LCAO_Orbitals& orb);
6060
void cal_exx_force();
6161
void cal_exx_stress();
62+
void cal_exx_ions(const bool write_cv = false);
63+
void cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
64+
const Parallel_Orbitals& pv,
65+
const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr);
6266
std::vector<std::vector<int>> get_abfs_nchis() const;
6367

6468
std::vector< std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>> Hexxs;
@@ -71,7 +75,7 @@ class Exx_LRI
7175
const Exx_Info::Exx_Info_RI &info;
7276
MPI_Comm mpi_comm;
7377
const K_Vectors *p_kv = nullptr;
74-
std::vector<double> orb_cutoff_;
78+
std::vector<double> orb_cutoff_;
7579

7680
std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> lcaos;
7781
std::vector<std::vector<std::vector<Numerical_Orbital_Lm>>> abfs;
@@ -80,10 +84,6 @@ class Exx_LRI
8084
LRI_CV<Tdata> cv;
8185
RI::Exx<TA,Tcell,Ndim,Tdata> exx_lri;
8286

83-
void cal_exx_ions(const bool write_cv = false);
84-
void cal_exx_elec(const std::vector<std::map<TA, std::map<TAC, RI::Tensor<Tdata>>>>& Ds,
85-
const Parallel_Orbitals& pv,
86-
const ModuleSymmetry::Symmetry_rotation* p_symrot = nullptr);
8787
void post_process_Hexx( std::map<TA, std::map<TAC, RI::Tensor<Tdata>>> &Hexxs_io ) const;
8888
double post_process_Eexx(const double& Eexx_in) const;
8989

@@ -99,4 +99,4 @@ class Exx_LRI
9999

100100
#include "Exx_LRI.hpp"
101101

102-
#endif
102+
#endif

0 commit comments

Comments
 (0)