diff --git a/examples/hse/pw_Si2/INPUT b/examples/hse/pw_Si2/INPUT new file mode 100644 index 0000000000..7e97df18ee --- /dev/null +++ b/examples/hse/pw_Si2/INPUT @@ -0,0 +1,32 @@ +INPUT_PARAMETERS +pseudo_dir ../../../tests/PP_ORB +orbital_dir ../../../tests/PP_ORB +nbands 4 +nspin 1 +calculation scf +basis_type pw +ks_solver dav +ecutwfc 50 +scf_thr 1e-9 +scf_nmax 100 +gamma_only 0 +symmetry -1 +smearing_method fixed +mixing_type broyden +mixing_beta 0.7 + +dft_functional hse + +# init_wfc file +# init_chg file + +pseudo_mesh 1 +pseudo_rcut 10 + +# out_wfc_pw 1 +# out_chg 1 + +exx_hybrid_alpha 0.25 + +cal_stress 1 +# cal_force 1 diff --git a/examples/hse/pw_Si2/KPT b/examples/hse/pw_Si2/KPT new file mode 100644 index 0000000000..8f50366e67 --- /dev/null +++ b/examples/hse/pw_Si2/KPT @@ -0,0 +1,4 @@ +K_POINTS +0 +Gamma +3 3 3 0 0 0 diff --git a/examples/hse/pw_Si2/STRU b/examples/hse/pw_Si2/STRU new file mode 100644 index 0000000000..7d3326d214 --- /dev/null +++ b/examples/hse/pw_Si2/STRU @@ -0,0 +1,24 @@ +ATOMIC_SPECIES +Si 1 Si_ONCV_PBE_FR-1.1.upf #Pseudopotentials are downloaded from http://www.quantum-simulation.org/potentials/sg15_oncv/upf/ + +NUMERICAL_ORBITAL +orb_Si.dat + +LATTICE_CONSTANT +1.889766 + +LATTICE_VECTORS +0.0 2.708337 2.708337 +2.708337 0.0 2.708337 +2.708337 2.708337 0.0 + + + +ATOMIC_POSITIONS +Direct + +Si +0 +2 +0.125 0.125 0.125 0 0 0 +0.875 0.875 0.875 0 0 0 diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 6fd3eeedf9..f15bde444f 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -50,6 +50,7 @@ VPATH=./src_global:\ ./module_hamilt_pw/hamilt_stodft:\ ./module_hamilt_pw/hamilt_pwdft/operator_pw:\ ./module_hamilt_pw/hamilt_pwdft/kernels:\ +./module_hamilt_pw/hamilt_pwdft/module_exx_helper:\ ./module_hamilt_pw/hamilt_stodft/kernels:\ ./module_hamilt_lcao/module_hcontainer:\ ./module_hamilt_lcao/hamilt_lcaodft:\ @@ -295,6 +296,7 @@ OBJS_HAMILT=hamilt_pw.o\ hamilt_sdft_pw.o\ operator.o\ operator_pw.o\ + 
op_exx_pw.o\ ekinetic_pw.o\ ekinetic_op.o\ hpsi_norm_op.o\ @@ -306,6 +308,7 @@ OBJS_HAMILT=hamilt_pw.o\ meta_op.o\ velocity_pw.o\ radial_proj.o\ + exx_helper.o\ OBJS_HAMILT_OF=kedf_tf.o\ kedf_vw.o\ @@ -671,6 +674,7 @@ OBJS_SRCPW=H_Ewald_pw.o\ sto_stress_pw.o\ stress_func_cc.o\ stress_func_ewa.o\ + stress_func_exx.o\ stress_func_gga.o\ stress_func_mgga.o\ stress_func_har.o\ diff --git a/source/module_elecstate/elecstate.h b/source/module_elecstate/elecstate.h index 3d99ca4910..23fbf17198 100644 --- a/source/module_elecstate/elecstate.h +++ b/source/module_elecstate/elecstate.h @@ -131,12 +131,8 @@ class ElecState bool vnew_exist = false; void cal_converged(); void cal_energies(const int type); -#ifdef __EXX -#ifdef __LCAO void set_exx(const double& Eexx); void set_exx(const std::complex& Eexx); -#endif //__LCAO -#endif //__EXX double get_hartree_energy(); double get_etot_efield(); diff --git a/source/module_elecstate/elecstate_exx.cpp b/source/module_elecstate/elecstate_exx.cpp index 8881a74807..f6d9acfc2d 100644 --- a/source/module_elecstate/elecstate_exx.cpp +++ b/source/module_elecstate/elecstate_exx.cpp @@ -3,8 +3,6 @@ namespace elecstate { -#ifdef __EXX -#ifdef __LCAO /// @brief calculation if converged /// @date Peize Lin add 2016-12-03 void ElecState::set_exx(const double& Eexx) @@ -17,7 +15,5 @@ void ElecState::set_exx(const double& Eexx) } return; } -#endif //__LCAO -#endif //__EXX } \ No newline at end of file diff --git a/source/module_elecstate/potentials/potential_new.h b/source/module_elecstate/potentials/potential_new.h index fd6e087534..5d421548b3 100644 --- a/source/module_elecstate/potentials/potential_new.h +++ b/source/module_elecstate/potentials/potential_new.h @@ -170,6 +170,12 @@ class Potential : public PotBase { return this->v_effective_fixed.data(); } + const ModulePW::PW_Basis *get_rho_basis() const + { + return this->rho_basis_; + } + // What about adding a function to get the wfc? 
+ // This is useful for the calculation of the exx energy /// @brief get the value of vloc at G=0; diff --git a/source/module_esolver/CMakeLists.txt b/source/module_esolver/CMakeLists.txt index 6618006588..6b76e8f84e 100644 --- a/source/module_esolver/CMakeLists.txt +++ b/source/module_esolver/CMakeLists.txt @@ -27,6 +27,8 @@ add_library( esolver OBJECT ${objects} + ../module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.cpp + ../module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.h ) if(ENABLE_COVERAGE) diff --git a/source/module_esolver/esolver_ks.cpp b/source/module_esolver/esolver_ks.cpp index 84c8f2525b..9d485f4306 100644 --- a/source/module_esolver/esolver_ks.cpp +++ b/source/module_esolver/esolver_ks.cpp @@ -2,6 +2,7 @@ #include "module_base/timer.h" #include "module_cell/cal_atoms_info.h" +#include "module_hamilt_general/module_xc/xc_functional.h" #include "module_io/cube_io.h" #include "module_io/json_output/init_info.h" #include "module_io/json_output/output_info.h" @@ -24,6 +25,8 @@ #include "module_cell/module_paw/paw_cell.h" #endif +#include "esolver_ks_pw.h" + namespace ModuleESolver { diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index 67b1be64ab..3bddfc2dfa 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -225,6 +225,26 @@ void ESolver_KS_PW::before_all_runners(UnitCell& ucell, const Input_p this->pelec->wg, this->pelec->skip_weights); } + + + // 10) initialize exx pw + if (PARAM.inp.calculation == "scf" + || PARAM.inp.calculation == "relax" + || PARAM.inp.calculation == "cell-relax" + || PARAM.inp.calculation == "md") + { + if (GlobalC::exx_info.info_global.cal_exx && GlobalC::exx_info.info_global.separate_loop == true) + { + XC_Functional::set_xc_first_loop(ucell); + exx_helper.set_firstiter(); + } + + if (GlobalC::exx_info.info_global.cal_exx) + { + exx_helper.set_wg(&this->pelec->wg); + } + } + } template @@ -258,6 +278,19 @@ void 
ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) // allocate HamiltPW this->allocate_hamilt(ucell); + if (PARAM.inp.calculation == "scf" + || PARAM.inp.calculation == "relax" + || PARAM.inp.calculation == "cell-relax" + || PARAM.inp.calculation == "md") + { + if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.basis_type == "pw") + { + auto hamilt_pw = reinterpret_cast*>(this->p_hamilt); + hamilt_pw->set_exx_helper(exx_helper); + } + + } + //---------------------------------------------------------- //! calculate the total local pseudopotential in real space //---------------------------------------------------------- @@ -333,6 +366,18 @@ void ESolver_KS_PW::before_scf(UnitCell& ucell, const int istep) this->already_initpsi = true; } + if (PARAM.inp.calculation == "scf" + || PARAM.inp.calculation == "relax" + || PARAM.inp.calculation == "cell-relax" + || PARAM.inp.calculation == "md") + { + if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.basis_type == "pw") + { + exx_helper.set_psi(kspw_psi); + } + + } + ModuleBase::timer::tick("ESolver_KS_PW", "before_scf"); } @@ -504,6 +549,11 @@ void ESolver_KS_PW::update_pot(UnitCell& ucell, const int istep, cons template void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int& iter, bool& conv_esolver) { + if (GlobalC::exx_info.info_global.cal_exx && !exx_helper.op_exx->first_iter) + { + this->pelec->set_exx(exx_helper.cal_exx_energy(kspw_psi)); + } + // deband is calculated from "output" charge density calculated // in sum_band // need 'rho(out)' and 'vr (v_h(in) and v_xc(in))' @@ -522,6 +572,33 @@ void ESolver_KS_PW::iter_finish(UnitCell& ucell, const int istep, int this->ppcell.cal_effective_D(veff, this->pw_rhod, ucell); } + if (GlobalC::exx_info.info_global.cal_exx) + { + if (GlobalC::exx_info.info_global.separate_loop) + { + if (conv_esolver) + { + exx_helper.set_psi(this->kspw_psi); + + conv_esolver = exx_helper.exx_after_converge(iter); + + if (!conv_esolver) + { + std::cout << " 
Setting Psi for EXX PW Inner Loop" << std::endl; + exx_helper.op_exx->first_iter = false; + XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); + update_pot(ucell, istep, iter, conv_esolver); + } + } + } + else + { +// std::cout << "setting psi for each iter" << std::endl; + exx_helper.set_psi(this->kspw_psi); + } + + } + // 3) Print out electronic wavefunctions in pw basis if (PARAM.inp.out_wfc_pw == 1 || PARAM.inp.out_wfc_pw == 2) { diff --git a/source/module_esolver/esolver_ks_pw.h b/source/module_esolver/esolver_ks_pw.h index 47ddb82e2a..d046bd28da 100644 --- a/source/module_esolver/esolver_ks_pw.h +++ b/source/module_esolver/esolver_ks_pw.h @@ -3,6 +3,8 @@ #include "./esolver_ks.h" #include "module_hamilt_pw/hamilt_pwdft/operator_pw/velocity_pw.h" #include "module_psi/psi_init.h" +#include "module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.h" +#include "module_hamilt_pw/hamilt_pwdft/global.h" #include #include @@ -31,6 +33,8 @@ class ESolver_KS_PW : public ESolver_KS void after_all_runners(UnitCell& ucell) override; + Exx_Helper exx_helper; + protected: virtual void before_scf(UnitCell& ucell, const int istep) override; diff --git a/source/module_hamilt_general/module_xc/xc_functional.cpp b/source/module_hamilt_general/module_xc/xc_functional.cpp index 5fc030ec13..3cf38afe03 100644 --- a/source/module_hamilt_general/module_xc/xc_functional.cpp +++ b/source/module_hamilt_general/module_xc/xc_functional.cpp @@ -40,6 +40,9 @@ method. 
*/ else if (ucell.atoms[0].ncpp.xc_func == "SCAN0") { XC_Functional::set_xc_type("scan"); } + else if (ucell.atoms[0].ncpp.xc_func == "B3LYP") { + XC_Functional::set_xc_type("blyp"); + } } // The setting values of functional id according to the index in LIBXC @@ -244,6 +247,12 @@ void XC_Functional::set_xc_type(const std::string xc_func_in) func_type = 2; use_libxc = true; } + else if (xc_func == "B3LYP") + { + func_id.push_back(XC_HYB_GGA_XC_B3LYP); + func_type = 4; + use_libxc = true; + } #endif else { @@ -268,19 +277,19 @@ void XC_Functional::set_xc_type(const std::string xc_func_in) std::cerr << "\n OPTX untested please test,"; } - if((func_type == 4 || func_type == 5) && PARAM.inp.basis_type == "pw") - { - ModuleBase::WARNING_QUIT("set_xc_type","hybrid functional not realized for planewave yet"); - } + // if((func_type == 4 || func_type == 5) && PARAM.inp.basis_type == "pw") + // { + // ModuleBase::WARNING_QUIT("set_xc_type","hybrid functional not realized for planewave yet"); + // } if((func_type == 3 || func_type == 5) && PARAM.inp.nspin==4) { ModuleBase::WARNING_QUIT("set_xc_type","meta-GGA has not been implemented for nspin = 4 yet"); } #ifndef __EXX - if(func_type == 4 || func_type == 5) + if((func_type == 4 || func_type == 5) && PARAM.inp.basis_type == "lcao") { - ModuleBase::WARNING_QUIT("set_xc_type","compile with libri to use hybrid functional"); + ModuleBase::WARNING_QUIT("set_xc_type","compile with libri to use hybrid functional in lcao basis"); } #endif diff --git a/source/module_hamilt_general/operator.h b/source/module_hamilt_general/operator.h index 80ed065ccc..941f3e29c1 100644 --- a/source/module_hamilt_general/operator.h +++ b/source/module_hamilt_general/operator.h @@ -18,6 +18,7 @@ enum class calculation_type pw_veff, pw_meta, pw_onsite, + pw_exx, lcao_overlap, lcao_fixed, lcao_gint, @@ -41,7 +42,6 @@ class Operator // this is the core function for Operator // do H|psi> from input |psi> , - /// as default, different operators donate hPsi 
independently /// run this->act function for the first operator and run all act() for other nodes in chain table /// if this procedure is not suitable for your operator, just override this function. @@ -90,6 +90,11 @@ class Operator return this->act_type; } + calculation_type get_cal_type() const + { + return this->cal_type; + } + protected: int ik = 0; int act_type = 1; ///< determine which act() interface would be called in hPsi() diff --git a/source/module_hamilt_lcao/module_tddft/propagator.h b/source/module_hamilt_lcao/module_tddft/propagator.h index ca4bec2e66..cb77300b4a 100644 --- a/source/module_hamilt_lcao/module_tddft/propagator.h +++ b/source/module_hamilt_lcao/module_tddft/propagator.h @@ -84,6 +84,7 @@ ct::Tensor create_identity_matrix(const int n, ct::DeviceType device = ct::Devic data_ptr[i * n + i] = init_value(); } } +#if ((defined __CUDA)) else if (device == ct::DeviceType::GpuDevice) { // For GPU, we need to use a kernel to set the diagonal elements @@ -94,6 +95,7 @@ ct::Tensor create_identity_matrix(const int n, ct::DeviceType device = ct::Devic ct::kernels::set_memory()(data_ptr + i * n + i, value, 1); } } +#endif return tensor; } diff --git a/source/module_hamilt_pw/hamilt_pwdft/CMakeLists.txt b/source/module_hamilt_pw/hamilt_pwdft/CMakeLists.txt index 9e797f3744..50dad68591 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/CMakeLists.txt +++ b/source/module_hamilt_pw/hamilt_pwdft/CMakeLists.txt @@ -9,6 +9,7 @@ list(APPEND objects operator_pw/velocity_pw.cpp operator_pw/operator_pw.cpp operator_pw/onsite_proj_pw.cpp + operator_pw/op_exx_pw.cpp forces_nl.cpp forces_cc.cpp forces_scc.cpp @@ -38,7 +39,7 @@ list(APPEND objects fs_nonlocal_tools.cpp fs_kin_tools.cpp radial_proj.cpp - onsite_projector.cpp + onsite_projector.cpp onsite_proj_tools.cpp ) @@ -46,6 +47,7 @@ add_library( hamilt_pwdft OBJECT ${objects} + stress_func_exx.cpp ) if(ENABLE_COVERAGE) diff --git a/source/module_hamilt_pw/hamilt_pwdft/global.h 
b/source/module_hamilt_pw/hamilt_pwdft/global.h index 8eb1bd91d4..d6f6be697f 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/global.h +++ b/source/module_hamilt_pw/hamilt_pwdft/global.h @@ -255,9 +255,9 @@ static const char* _hipfftGetErrorString(hipfftResult_t error) //========================================================== namespace GlobalC { -#ifdef __EXX +//#ifdef __EXX extern Exx_Info exx_info; -#endif +//#endif } // namespace GlobalC #include "module_cell/parallel_kpoints.h" diff --git a/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.cpp index 6612ed8ed5..6f8eead7da 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.cpp @@ -11,6 +11,7 @@ #include "operator_pw/meta_pw.h" #include "operator_pw/nonlocal_pw.h" #include "operator_pw/onsite_proj_pw.h" +#include "operator_pw/op_exx_pw.h" #ifdef USE_PAW #include "module_cell/module_paw/paw_cell.h" @@ -121,6 +122,19 @@ HamiltPW::HamiltPW(elecstate::Potential* pot_in, = new OnsiteProj>(isk, ucell, PARAM.inp.sc_mag_switch, (PARAM.inp.dft_plus_u>0)); this->ops->add(onsite_proj); } + if (GlobalC::exx_info.info_global.cal_exx) + { + auto exx = new OperatorEXXPW(isk, wfc_basis, pot_in->get_rho_basis(), pkv, ucell); + if (this->ops == nullptr) + { + this->ops = exx; + } + else + { + this->ops->add(exx); + // exx->set_psi(&this->psi); + } + } return; } @@ -382,6 +396,22 @@ void HamiltPW::sPsi(const T* psi_in, // psi } } +template +void HamiltPW::set_exx_helper(Exx_Helper &exx_helper) +{ + auto op = this->ops; + while (op != nullptr) + { + if (op->get_cal_type() == calculation_type::pw_exx) + { + exx_helper.op_exx = reinterpret_cast*>(op); + exx_helper.set_op(); + + } + op = op->next_op; + } +} + template class HamiltPW, base_device::DEVICE_CPU>; template class HamiltPW, base_device::DEVICE_CPU>; // template HamiltPW, base_device::DEVICE_CPU>::HamiltPW(const HamiltPW, diff --git 
a/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h b/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h index badeae0db6..777dc3c9df 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h @@ -4,9 +4,11 @@ #include "module_base/macros.h" #include "module_cell/klist.h" #include "module_elecstate/potentials/potential_new.h" +#include "module_esolver/esolver_ks_pw.h" #include "module_hamilt_general/hamilt.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" #include "module_base/kernels/math_kernel_op.h" +#include "module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.h" namespace hamilt { @@ -35,6 +37,8 @@ class HamiltPW : public Hamilt const int nbands // number of bands ) const override; + void set_exx_helper(Exx_Helper& exx_helper_in); + protected: // used in sPhi, which are calculated in hPsi or sPhi const pseudopot_cell_vnl* ppcell = nullptr; diff --git a/source/module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.cpp b/source/module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.cpp new file mode 100644 index 0000000000..a0982737d2 --- /dev/null +++ b/source/module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.cpp @@ -0,0 +1,56 @@ +#include "exx_helper.h" + +template +double Exx_Helper::cal_exx_energy(psi::Psi *psi_) +{ + return op_exx->cal_exx_energy(psi_); + +} + +template +bool Exx_Helper::exx_after_converge(int &iter) +{ + if (op_exx->first_iter) + { + op_exx->first_iter = false; + } + else if (!GlobalC::exx_info.info_global.separate_loop) + { + return true; + } + else if (iter == 1) + { + return true; + } + GlobalV::ofs_running << "Updating EXX and rerun SCF" << std::endl; + iter = 0; + return false; + +} + +template +void Exx_Helper::set_psi(psi::Psi *psi_) +{ + if (psi_ == nullptr) + return; + op_exx->set_psi(*psi_); + if (PARAM.inp.exxace) + { + op_exx->construct_ace(); + } +} + +template class Exx_Helper, base_device::DEVICE_CPU>; +template class Exx_Helper, 
base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template class Exx_Helper, base_device::DEVICE_GPU>; +template class Exx_Helper, base_device::DEVICE_GPU>; +#endif + +#ifndef __EXX +#include "module_hamilt_general/module_xc/exx_info.h" +namespace GlobalC +{ + Exx_Info exx_info; +} +#endif \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.h b/source/module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.h new file mode 100644 index 0000000000..4464bdab0a --- /dev/null +++ b/source/module_hamilt_pw/hamilt_pwdft/module_exx_helper/exx_helper.h @@ -0,0 +1,43 @@ +// +// For EXX in PW. +// +#include "module_psi/psi.h" +#include "module_base/matrix.h" +#include "module_hamilt_pw/hamilt_pwdft/global.h" +#include "module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.h" + +#ifndef EXX_HELPER_H +#define EXX_HELPER_H +template +struct Exx_Helper +{ + using Real = typename GetTypeReal::type; + using OperatorEXX = hamilt::OperatorEXXPW; + + public: + Exx_Helper() = default; + OperatorEXX *op_exx = nullptr; // non-owning; null until an EXX operator is attached + + void set_firstiter() { first_iter = true; } + void set_wg(const ModuleBase::matrix *wg_) { wg = wg_; } + void set_psi(psi::Psi *psi_); + + void set_op() + { + op_exx->first_iter = first_iter; + set_psi(psi); + op_exx->set_wg(wg); + } + + bool exx_after_converge(int &iter); + + double cal_exx_energy(psi::Psi *psi_); + + private: + bool first_iter = false; // stays false unless set_firstiter() is called + psi::Psi *psi = nullptr; // never assigned here; nullptr lets set_psi(psi) in set_op() return early + const ModuleBase::matrix *wg = nullptr; // non-owning; supplied through set_wg() + + +}; +#endif // EXX_HELPER_H diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp new file mode 100644 index 0000000000..4b6d36908a --- /dev/null +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp @@ -0,0 +1,919 @@ +#include "module_base/constants.h" +#include "module_base/global_variable.h" +#include "module_base/parallel_reduce.h" +#include "module_base/timer.h" +#include "module_cell/klist.h" +#include 
"module_hamilt_general/operator.h" +#include "module_psi/psi.h" +#include "module_base/tool_quit.h" + +#include +#include +#include +#include +#include + +extern "C" +{ + void ztrtri_(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info); + void ctrtri_(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info); +} + +//extern "C" void zpotrf_(char* uplo, const int* n, std::complex* A, const int* lda, int* info); +//extern "C" void cpotrf_(char* uplo, const int* n, std::complex* A, const int* lda, int* info); + +#include "op_exx_pw.h" +#include "module_hamilt_pw/hamilt_pwdft/global.h" + +namespace hamilt +{ +template +struct trtri_op +{ + void operator()(char *uplo, char *diag, int *n, T *a, int *lda, int *info) + { + std::cout << "trtri_op not implemented" << std::endl; + } +}; + +template +struct potrf_op +{ + void operator()(char *uplo, int *n, T *a, int *lda, int *info) + { + std::cout << "potrf_op not implemented" << std::endl; + } +}; + +template +OperatorEXXPW::OperatorEXXPW(const int* isk_in, + const ModulePW::PW_Basis_K* wfcpw_in, + const ModulePW::PW_Basis* rhopw_in, + K_Vectors *kv_in, + const UnitCell *ucell) + : isk(isk_in), wfcpw(wfcpw_in), rhopw(rhopw_in), kv(kv_in), ucell(ucell) +{ + + if (GlobalV::KPAR != 1) + { + // GlobalV::ofs_running << "EXX Calculation does not support k-point parallelism" << std::endl; + ModuleBase::WARNING_QUIT("OperatorEXXPW", "EXX Calculation does not support k-point parallelism"); + } + + this->classname = "OperatorEXXPW"; + this->ctx = nullptr; + this->cpu_ctx = nullptr; + this->cal_type = hamilt::calculation_type::pw_exx; + + // allocate real space memory + // assert(wfcpw->nrxx == rhopw->nrxx); + resmem_complex_op()(psi_nk_real, wfcpw->nrxx); + resmem_complex_op()(psi_mq_real, wfcpw->nrxx); + resmem_complex_op()(density_real, rhopw->nrxx); + resmem_complex_op()(h_psi_real, rhopw->nrxx); + // allocate density recip space memory + resmem_complex_op()(density_recip, rhopw->npw); + // allocate 
h_psi recip space memory + resmem_complex_op()(h_psi_recip, wfcpw->npwk_max); + // resmem_complex_op()(this->ctx, psi_all_real, wfcpw->nrxx * GlobalV::NBANDS); + int nks = wfcpw->nks; +// std::cout << "nks: " << nks << std::endl; + resmem_real_op()(pot, rhopw->npw * nks * nks); + + tpiba = ucell->tpiba; + Real tpiba2 = tpiba * tpiba; + // calculate the exx_divergence + exx_divergence(); + +} + +template +OperatorEXXPW::~OperatorEXXPW() +{ + // use delete_memory_op to delete the allocated pws + delmem_complex_op()(psi_nk_real); + delmem_complex_op()(psi_mq_real); + delmem_complex_op()(density_real); + delmem_complex_op()(h_psi_real); + delmem_complex_op()(density_recip); + delmem_complex_op()(h_psi_recip); + + delmem_real_op()(pot); + + delmem_complex_op()(h_psi_ace); + delmem_complex_op()(psi_h_psi_ace); + delmem_complex_op()(L_ace); + for (auto &Xi_ace: Xi_ace_k) + { + delmem_complex_op()(Xi_ace); + } + Xi_ace_k.clear(); + +} + +template +inline bool is_finite(const T &val) +{ + return std::isfinite(val); +} + +template <> +inline bool is_finite(const std::complex &val) +{ + return std::isfinite(val.real()) && std::isfinite(val.imag()); +} + +template <> +inline bool is_finite(const std::complex &val) +{ + return std::isfinite(val.real()) && std::isfinite(val.imag()); +} + +template +void OperatorEXXPW::act(const int nbands, + const int nbasis, + const int npol, + const T *tmpsi_in, + T *tmhpsi, + const int ngk_ik, + const bool is_first_node) const +{ + if (first_iter) return; + + if (is_first_node) + { + setmem_complex_op()(tmhpsi, 0, nbasis*nbands/npol); + } + + if (PARAM.inp.exxace) + { + act_op_ace(nbands, nbasis, npol, tmpsi_in, tmhpsi, ngk_ik, is_first_node); + } + else + { + act_op(nbands, nbasis, npol, tmpsi_in, tmhpsi, ngk_ik, is_first_node); + } +} + +template +void OperatorEXXPW::act_op(const int nbands, + const int nbasis, + const int npol, + const T *tmpsi_in, + T *tmhpsi, + const int ngk_ik, + const bool is_first_node) const +{ + if (!potential_got) 
+ { + get_potential(); + potential_got = true; + } + +// set_psi(&p_exx_helper->psi); + + ModuleBase::timer::tick("OperatorEXXPW", "act_op"); + + setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max); + setmem_complex_op()(h_psi_real, 0, rhopw->nrxx); + setmem_complex_op()(density_real, 0, rhopw->nrxx); + setmem_complex_op()(density_recip, 0, rhopw->npw); + // setmem_complex_op()(psi_all_real, 0, wfcpw->nrxx * GlobalV::NBANDS); + // std::map, bool> has_real; + setmem_complex_op()(psi_nk_real, 0, wfcpw->nrxx); + setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx); + + // ik fixed here, select band n + for (int n_iband = 0; n_iband < nbands; n_iband++) + { + const T *psi_nk = tmpsi_in + n_iband * nbasis; + // retrieve \psi_nk in real space + wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, this->ik); + + // for \psi_nk, get the pw of iq and band m + auto q_points = get_q_points(this->ik); + Real nqs = q_points.size(); + for (int iq: q_points) + { + for (int m_iband = 0; m_iband < psi.get_nbands(); m_iband++) + { + // double wg_mqb_real = GlobalC::exx_helper.wg(iq, m_iband); + double wg_mqb_real = (*wg)(iq, m_iband); // weight of band m at q-point iq; divided by wk[iq] below + T wg_mqb = wg_mqb_real; + if (wg_mqb_real < 1e-12) + { + continue; + } + + // if (has_real.find({iq, m_iband}) == has_real.end()) + // { + const T* psi_mq = get_pw(m_iband, iq); + wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); + // syncmem_complex_op()(this->ctx, this->ctx, psi_all_real + m_iband * wfcpw->nrxx, psi_mq_real, wfcpw->nrxx); + // has_real[{iq, m_iband}] = true; + // } + // else + // { + // // const T* psi_mq = get_pw(m_iband, iq); + // // wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); + // syncmem_complex_op()(this->ctx, this->ctx, psi_mq_real, psi_all_real + m_iband * wfcpw->nrxx, wfcpw->nrxx); + // } + + // direct multiplication in real space, \psi_nk(r) * \psi_mq(r) + #ifdef _OPENMP + #pragma omp parallel for schedule(static) + #endif + for (int ir = 0; ir < wfcpw->nrxx; ir++) + { + // assert(is_finite(psi_nk_real[ir])); + // 
assert(is_finite(psi_mq_real[ir])); + Real ucell_omega = ucell->omega; + density_real[ir] = psi_nk_real[ir] * std::conj(psi_mq_real[ir]) / ucell_omega; // Phase e^(i(q-k)r) + } + // to be changed into kernel function + + // bring the density to recip space + rhopw->real2recip(density_real, density_recip); + + // multiply the density with the potential in recip space + multiply_potential(density_recip, this->ik, iq); + + // bring the potential back to real space + rhopw->recip2real(density_recip, density_real); + + // get the h|psi_ik>(r), save in density_real + #ifdef _OPENMP + #pragma omp parallel for schedule(static) + #endif + for (int ir = 0; ir < wfcpw->nrxx; ir++) + { + // assert(is_finite(psi_mq_real[ir])); + // assert(is_finite(density_real[ir])); + density_real[ir] *= psi_mq_real[ir]; + } + + T wk_iq = kv->wk[iq]; + T wk_ik = kv->wk[this->ik]; + + #ifdef _OPENMP + #pragma omp parallel for schedule(static) + #endif + for (int ir = 0; ir < wfcpw->nrxx; ir++) + { + h_psi_real[ir] += density_real[ir] * wg_mqb / wk_iq / nqs; + } + + } // end of m_iband + setmem_complex_op()(density_real, 0, rhopw->nrxx); + setmem_complex_op()(density_recip, 0, rhopw->npw); + setmem_complex_op()(psi_mq_real, 0, wfcpw->nrxx); + + } // end of iq + auto h_psi_nk = tmhpsi + n_iband * nbasis; + Real hybrid_alpha = GlobalC::exx_info.info_global.hybrid_alpha; + wfcpw->real_to_recip(ctx, h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha); + setmem_complex_op()(h_psi_real, 0, rhopw->nrxx); + + } + + ModuleBase::timer::tick("OperatorEXXPW", "act_op"); + +} + +template +void OperatorEXXPW::act_op_ace(const int nbands, + const int nbasis, + const int npol, + const T *tmpsi_in, + T *tmhpsi, + const int ngk_ik, + const bool is_first_node) const +{ + ModuleBase::timer::tick("OperatorEXXPW", "act_op_ace"); + +// std::cout << "act_op_ace" << std::endl; + // hpsi += -Xi^\dagger * Xi * psi + auto Xi_ace = Xi_ace_k[this->ik]; + int nbands_tot = psi.get_nbands(); + int nbasis_max = 
psi.get_nbasis(); +// T* hpsi = nullptr; +// resmem_complex_op()(hpsi, nbands_tot * nbasis); +// setmem_complex_op()(hpsi, 0, nbands_tot * nbasis); + T* Xi_psi = nullptr; + resmem_complex_op()(Xi_psi, nbands_tot * nbands); + setmem_complex_op()(Xi_psi, 0, nbands_tot * nbands); + + char trans_N = 'N', trans_T = 'T', trans_C = 'C'; + T intermediate_one = 1.0, intermediate_zero = 0.0, intermediate_minus_one = -1.0; + // Xi * psi + gemm_complex_op()(trans_N, + trans_N, + nbands_tot, + nbands, + nbasis, + &intermediate_one, + Xi_ace, + nbands_tot, + tmpsi_in, + nbasis, + &intermediate_zero, + Xi_psi, + nbands_tot + ); + + Parallel_Reduce::reduce_pool(Xi_psi, nbands_tot * nbands); + + // Xi^\dagger * (Xi * psi) + gemm_complex_op()(trans_C, + trans_N, + nbasis, + nbands, + nbands_tot, + &intermediate_minus_one, + Xi_ace, + nbands_tot, + Xi_psi, + nbands_tot, + &intermediate_one, + tmhpsi, + nbasis + ); + + +// // negative sign, add to hpsi +// vec_add_vec_complex_op()(this->ctx, nbands * nbasis, tmhpsi, hpsi, -1, tmhpsi, 1); +// delmem_complex_op()(hpsi); + delmem_complex_op()(Xi_psi); + ModuleBase::timer::tick("OperatorEXXPW", "act_op_ace"); // closes the "act_op_ace" tick opened at function entry + +} + +template +void OperatorEXXPW::construct_ace() const +{ +// int nkb = p_exx_helper->psi.get_nbands() * p_exx_helper->psi.get_nk(); + int nbands = psi.get_nbands(); + int nbasis = psi.get_nbasis(); + int nk = psi.get_nk(); + + int ik_save = this->ik; + int * ik_ = const_cast(&this->ik); + + T intermediate_one = 1.0, intermediate_zero = 0.0; + + if (h_psi_ace == nullptr) + { + resmem_complex_op()(h_psi_ace, nbands * nbasis); + setmem_complex_op()(h_psi_ace, 0, nbands * nbasis); + } + + if (Xi_ace_k.size() != nk) + { + Xi_ace_k.resize(nk); + for (int i = 0; i < nk; i++) + { + resmem_complex_op()(Xi_ace_k[i], nbands * nbasis); + } + } + + for (int i = 0; i < nk; i++) + { + setmem_complex_op()(Xi_ace_k[i], 0, nbands * nbasis); + } + + if (L_ace == nullptr) + { + resmem_complex_op()(L_ace, nbands * nbands); + setmem_complex_op()(L_ace, 
0, nbands * nbands); + } + + if (psi_h_psi_ace == nullptr) + { + resmem_complex_op()(psi_h_psi_ace, nbands * nbands); + } + + for (int ik = 0; ik < nk; ik++) + { + int npwk = wfcpw->npwk[ik]; + + T* Xi_ace = Xi_ace_k[ik]; + psi.fix_kb(ik, 0); + T* p_psi = psi.get_pointer(); + + setmem_complex_op()(h_psi_ace, 0, nbands * nbasis); + + *ik_ = ik; + + act_op( + nbands, + nbasis, + 1, + p_psi, + h_psi_ace, + nbasis, + false + ); + + // psi_h_psi_ace = psi^\dagger * h_psi_ace + // p_exx_helper->psi.fix_kb(0, 0); + gemm_complex_op()('C', + 'N', + nbands, + nbands, + npwk, + &intermediate_one, + p_psi, + nbasis, + h_psi_ace, + nbasis, + &intermediate_zero, + psi_h_psi_ace, + nbands); + + // reduction of psi_h_psi_ace, due to distributed memory + Parallel_Reduce::reduce_pool(psi_h_psi_ace, nbands * nbands); + + // L_ace = cholesky(-psi_h_psi_ace) + #ifdef _OPENMP + #pragma omp parallel for schedule(static) + #endif + for (int i = 0; i < nbands; i++) + { + for (int j = 0; j < nbands; j++) + { + L_ace[i * nbands + j] = -psi_h_psi_ace[i * nbands + j]; + } + } + + int info = 0; + char up = 'U', lo = 'L'; + + potrf_op()(&lo, &nbands, L_ace, &nbands, &info); + + // expand for-loop + #ifdef _OPENMP + #pragma omp parallel for schedule(static) collapse(2) + #endif + for (int i = 0; i < nbands; i++) + { + for (int j = 0; j < nbands; j++) + { + if (j < i) + { + // L_ace[j * nkb + i] = std::conj(L_ace[i * nkb + j]); + L_ace[i * nbands + j] = 0.0; + } + } + } + + // L_ace inv in place + // T == std::complex or std::complex + char non = 'N'; + trtri_op()(&lo, &non, &nbands, L_ace, &nbands, &info); + + // Xi_ace = L_ace^-1 * h_psi_ace^dagger + gemm_complex_op()('N', + 'C', + nbands, + npwk, + nbands, + &intermediate_one, + L_ace, + nbands, + h_psi_ace, + nbasis, + &intermediate_zero, + Xi_ace, + nbands); + + // clear mem + setmem_complex_op()(h_psi_ace, 0, nbands * nbasis); + setmem_complex_op()(psi_h_psi_ace, 0, nbands * nbands); + setmem_complex_op()(L_ace, 0, nbands * nbands); + + } + 
+ *ik_ = ik_save; + +} + +template +std::vector OperatorEXXPW::get_q_points(const int ik) const +{ + // stored in q_points + if (q_points.find(ik) != q_points.end()) + { + return q_points.find(ik)->second; + } + + std::vector q_points_ik; + + // if () // downsampling + { + for (int iq = 0; iq < wfcpw->nks; iq++) + { + q_points_ik.push_back(iq); + } + } + // else + // { + // for (int iq = 0; iq < wfcpw->nks; iq++) + // { + // kv-> + // } + // } + + q_points[ik] = q_points_ik; + return q_points_ik; +} + +template +void OperatorEXXPW::multiply_potential(T *density_recip, int ik, int iq) const +{ + ModuleBase::timer::tick("OperatorEXXPW", "multiply_potential"); + int npw = rhopw->npw; + int nks = wfcpw->nks; + + #ifdef _OPENMP + #pragma omp parallel for schedule(static) + #endif + for (int ig = 0; ig < npw; ig++) + { + density_recip[ig] *= pot[ik * nks * npw + iq * npw + ig]; + + } + + ModuleBase::timer::tick("OperatorEXXPW", "multiply_potential"); +} + +template +const T *OperatorEXXPW::get_pw(const int m, const int iq) const +{ + // return pws[iq].get() + m * wfcpw->npwk[iq]; + psi.fix_kb(iq, m); + auto psi_mq = psi.get_pointer(); + return psi_mq; +} + +template +template +OperatorEXXPW::OperatorEXXPW(const OperatorEXXPW *op) +{ + // copy all the datas + this->isk = op->isk; + this->wfcpw = op->wfcpw; + this->rhopw = op->rhopw; + this->psi = op->psi; + this->ctx = op->ctx; + this->cpu_ctx = op->cpu_ctx; + resmem_complex_op()(this->ctx, psi_nk_real, wfcpw->nrxx); + resmem_complex_op()(this->ctx, psi_mq_real, wfcpw->nrxx); + resmem_complex_op()(this->ctx, density_real, rhopw->nrxx); + resmem_complex_op()(this->ctx, h_psi_real, rhopw->nrxx); + resmem_complex_op()(this->ctx, density_recip, rhopw->npw); + resmem_complex_op()(this->ctx, h_psi_recip, wfcpw->npwk_max); +// this->pws.resize(wfcpw->nks); + + +} + +template +void OperatorEXXPW::get_potential() const +{ + int nks = wfcpw->nks, npw = rhopw->npw; + double tpiba2 = tpiba * tpiba; + // calculate the pot + for (int 
ik = 0; ik < nks; ik++) + { + for (int iq = 0; iq < nks; iq++) + { + auto k = wfcpw->kvec_c[ik]; + auto q = wfcpw->kvec_c[iq]; + + #ifdef _OPENMP + #pragma omp parallel for schedule(static) + #endif + for (int ig = 0; ig < rhopw->npw; ig++) + { + Real gg = (k - q + rhopw->gcar[ig]).norm2() * tpiba2; + Real hse_omega2 = GlobalC::exx_info.info_global.hse_omega * GlobalC::exx_info.info_global.hse_omega; + // if (kqgcar2 > 1e-12) // vasp uses 1/40 of the smallest (k spacing)**2 + if (gg >= 1e-8) + { + Real fac = -ModuleBase::FOUR_PI * ModuleBase::e2 / gg; + // if (PARAM.inp.dft_functional == "hse") + if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc) + { + pot[ik * nks * npw + iq * npw + ig] = fac * (1.0 - std::exp(-gg / 4.0 / hse_omega2)); + } + else + { + pot[ik * nks * npw + iq * npw + ig] = fac; + } + } + // } + else + { + // if (PARAM.inp.dft_functional == "hse") + if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc) + { + pot[ik * nks * npw + iq * npw + ig] = exx_div - ModuleBase::PI * ModuleBase::e2 / hse_omega2; + } + else + { + pot[ik * nks * npw + iq * npw + ig] = exx_div; + } + } + // assert(is_finite(density_recip[ig])); + } + } + } +} + +template +void OperatorEXXPW::exx_divergence() +{ + if (GlobalC::exx_info.info_lip.lambda == 0.0) + { + return; + } + + // here we follow the exx_divergence subroutine in q-e (PW/src/exx_base.f90) + double alpha = 10.0 / wfcpw->gk_ecut; + double tpiba2 = tpiba * tpiba; + double div = 0; + + // this is the \sum_q F(q) part + // temporarily for all k points, should be replaced to q points later + for (int ik = 0; ik < wfcpw->nks; ik++) + { + auto k = wfcpw->kvec_c[ik]; +#ifdef _OPENMP +#pragma omp parallel for reduction(+:div) +#endif + for (int ig = 0; ig < rhopw->npw; ig++) + { + auto q = k + rhopw->gcar[ig]; + double qq = q.norm2(); + if (qq <= 1e-8) continue; + // else if (PARAM.inp.dft_functional == "hse") + else if (GlobalC::exx_info.info_global.ccp_type == 
Conv_Coulomb_Pot_K::Ccp_Type::Erfc) + { + double omega = GlobalC::exx_info.info_global.hse_omega; + double omega2 = omega * omega; + div += std::exp(-alpha * qq) / qq * (1.0 - std::exp(-qq*tpiba2 / 4.0 / omega2)); + } + else + { + div += std::exp(-alpha * qq) / qq; + } + } + } + + Parallel_Reduce::reduce_pool(div); + // std::cout << "EXX div: " << div << std::endl; + + // if (PARAM.inp.dft_functional == "hse") + if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc) + { + double omega = GlobalC::exx_info.info_global.hse_omega; + div += tpiba2 / 4.0 / omega / omega; // compensate for the finite value when qq = 0 + } + else + { + div -= alpha; + } + + div *= ModuleBase::e2 * ModuleBase::FOUR_PI / tpiba2 / wfcpw->nks; + + // numerically value the mean value of F(q) in the reciprocal space + // This means we need to calculate the average of F(q) in the first brillouin zone + alpha /= tpiba2; + int nqq = 100000; + double dq = 5.0 / std::sqrt(alpha) / nqq; + double aa = 0.0; + // if (PARAM.inp.dft_functional == "hse") + if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc) + { + double omega = GlobalC::exx_info.info_global.hse_omega; + double omega2 = omega * omega; +#ifdef _OPENMP +#pragma omp parallel for reduction(+:aa) +#endif + for (int i = 0; i < nqq; i++) + { + double q = dq * (i+0.5); + aa -= exp(-alpha * q * q) * exp(-q*q / 4.0 / omega2) * dq; + } + } + aa *= 8 / ModuleBase::FOUR_PI; + aa += 1.0 / std::sqrt(alpha * ModuleBase::PI); + + // printf("ucell: %p\n", ucell); + double omega = ucell->omega; + div -= ModuleBase::e2 * omega * aa; + exx_div = div * wfcpw->nks; + // std::cout << "EXX divergence: " << exx_div << std::endl; + + return; + +} + +template +double OperatorEXXPW::cal_exx_energy(psi::Psi *psi_) const +{ + if (PARAM.inp.exxace) + { + return cal_exx_energy_ace(psi_); + } + else + { + return cal_exx_energy_op(psi_); + } +} + +template +double OperatorEXXPW::cal_exx_energy_ace(psi::Psi *ppsi_) const 
+{ + double Eexx = 0; + + psi::Psi psi_ = *ppsi_; + int *ik_ = const_cast(&this->ik); + int ik_save = this->ik; + for (int i = 0; i < wfcpw->nks; i++) + { + setmem_complex_op()(h_psi_ace, 0, psi_.get_nbands() * psi_.get_nbasis()); + *ik_ = i; + psi_.fix_kb(i, 0); + auto psi_i = psi_.get_pointer(); + act_op_ace(psi_.get_nbands(), psi_.get_nbasis(), 1, psi_i, h_psi_ace, 0, true); + + for (int nband = 0; nband < psi_.get_nbands(); nband++) + { + psi_.fix_kb(i, nband); + auto psi_i_n = psi_.get_pointer(); + auto hpsi_i_n = h_psi_ace + nband * psi_.get_nbasis(); + double wg_i_n = (*wg)(i, nband); + // Eexx += dot(psi_i_n, h_psi_i_n) + Eexx += dot_op()(psi_.get_nbasis(), psi_i_n, hpsi_i_n, false) * wg_i_n * 2; + + } + + + } + + Parallel_Reduce::reduce_pool(Eexx); + *ik_ = ik_save; + return Eexx; +} + +template +double OperatorEXXPW::cal_exx_energy_op(psi::Psi *ppsi_) const +{ + psi::Psi psi_ = *ppsi_; + + using setmem_complex_op = base_device::memory::set_memory_op; + using delmem_complex_op = base_device::memory::delete_memory_op; + T* psi_nk_real = new T[wfcpw->nrxx]; + T* psi_mq_real = new T[wfcpw->nrxx]; + T* h_psi_recip = new T[wfcpw->npwk_max]; + T* h_psi_real = new T[wfcpw->nrxx]; + T* density_real = new T[wfcpw->nrxx]; + T* density_recip = new T[rhopw->npw]; + + if (wg == nullptr) return 0.0; + // evaluate the Eexx + // T Eexx_ik = 0.0; + double Eexx_ik_real = 0.0; + for (int ik = 0; ik < wfcpw->nks; ik++) + { + // auto k = this->pw_wfc->kvec_c[ik]; + // std::cout << k << std::endl; + for (int n_iband = 0; n_iband < psi.get_nbands(); n_iband++) + { + setmem_complex_op()(h_psi_recip, 0, wfcpw->npwk_max); + setmem_complex_op()(h_psi_real, 0, rhopw->nrxx); + setmem_complex_op()(density_real, 0, rhopw->nrxx); + setmem_complex_op()(density_recip, 0, rhopw->npw); + + // double wg_ikb_real = GlobalC::exx_helper.wg(this->ik, n_iband); + double wg_ikb_real = (*wg)(ik, n_iband); + T wg_ikb = wg_ikb_real; + if (wg_ikb_real < 1e-12) + { + continue; + } + + // std::cout << 
"ik = " << ik << " nb = " << n_iband << " wg_ikb = " << wg_ikb_real << std::endl; + + // const T *psi_nk = get_pw(n_iband, ik); + psi.fix_kb(ik, n_iband); + const T* psi_nk = psi.get_pointer(); + // retrieve \psi_nk in real space + wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, ik); + + // for \psi_nk, get the pw of iq and band m + // q_points is a vector of integers, 0 to nks-1 + std::vector q_points; + for (int iq = 0; iq < wfcpw->nks; iq++) + { + q_points.push_back(iq); + } + double nqs = q_points.size(); + + // std::cout << "ik = " << ik << " ib = " << n_iband << " wg_kb = " << wg_ikb_real << " wk_ik = " << kv->wk[ik] << std::endl; + for (int iq: q_points) + { + for (int m_iband = 0; m_iband < psi.get_nbands(); m_iband++) + { + // double wg_f = GlobalC::exx_helper.wg(iq, m_iband); + double wg_iqb_real = (*wg)(iq, m_iband); + T wg_iqb = wg_iqb_real; + if (wg_iqb_real < 1e-12) + { + continue; + } + + // std::cout << "iq = " << iq << " mb = " << m_iband << " wg_iqb = " << wg_iqb_real << std::endl; + + psi_.fix_kb(iq, m_iband); + const T* psi_mq = psi_.get_pointer(); + // const T* psi_mq = get_pw(m_iband, iq); + wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); + + T omega_inv = 1.0 / ucell->omega; + + // direct multiplication in real space, \psi_nk(r) * \psi_mq(r) + #ifdef _OPENMP + #pragma omp parallel for + #endif + for (int ir = 0; ir < wfcpw->nrxx; ir++) + { + // assert(is_finite(psi_nk_real[ir])); + // assert(is_finite(psi_mq_real[ir])); + density_real[ir] = psi_nk_real[ir] * std::conj(psi_mq_real[ir]) * omega_inv; + } + // to be changed into kernel function + + // bring the density to recip space + rhopw->real2recip(density_real, density_recip); + + #ifdef _OPENMP + #pragma omp parallel for reduction(+:Eexx_ik_real) + #endif + for (int ig = 0; ig < rhopw->npw; ig++) + { + int nks = wfcpw->nks; + int npw = rhopw->npw; + Real Fac = pot[ik * nks * npw + iq * npw + ig]; + Eexx_ik_real += Fac * (density_recip[ig] * std::conj(density_recip[ig])).real() + * 
wg_iqb_real / nqs * wg_ikb_real / kv->wk[ik]; + } + + } // m_iband + + } // iq + + } // n_iband + + } // ik + Eexx_ik_real *= 0.5 * ucell->omega; + Parallel_Reduce::reduce_pool(Eexx_ik_real); + // std::cout << "omega = " << this_->pelec->omega << " tpiba = " << this_->pw_rho->tpiba2 << " exx_div = " << exx_div << std::endl; + + double Eexx = Eexx_ik_real; + return Eexx; +} + +template <> +void trtri_op, base_device::DEVICE_CPU>::operator()(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info) +{ + ctrtri_(uplo, diag, n, a, lda, info); +} + +template <> +void trtri_op, base_device::DEVICE_CPU>::operator()(char *uplo, char *diag, int *n, std::complex *a, int *lda, int *info) +{ + ztrtri_(uplo, diag, n, a, lda, info); +} + +template <> +void potrf_op, base_device::DEVICE_CPU>::operator()(char *uplo, int *n, std::complex *a, int *lda, int *info) +{ + cpotrf_(uplo, n, a, lda, info); +} + +template <> +void potrf_op, base_device::DEVICE_CPU>::operator()(char *uplo, int *n, std::complex *a, int *lda, int *info) +{ + zpotrf_(uplo, n, a, lda, info); +} + +template class OperatorEXXPW, base_device::DEVICE_CPU>; +template class OperatorEXXPW, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template class OperatorEXXPW, base_device::DEVICE_GPU>; +template class OperatorEXXPW, base_device::DEVICE_GPU>; +#endif + +} // namespace hamilt diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.h new file mode 100644 index 0000000000..af2d65d4a8 --- /dev/null +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.h @@ -0,0 +1,150 @@ +#ifndef OPEXXPW_H +#define OPEXXPW_H + +#include "module_base/matrix.h" +#include "module_basis/module_pw/pw_basis.h" +#include "module_cell/klist.h" +#include "module_psi/psi.h" +#include "operator_pw.h" +#include "module_basis/module_pw/pw_basis_k.h" +#include "module_base/macros.h" +#include 
"module_base/kernels/math_kernel_op.h" +#include "module_base/blas_connector.h" + +#include +#include +#include + +namespace hamilt +{ + +template +class OperatorEXXPW : public OperatorPW +{ + private: + using Real = typename GetTypeReal::type; + + public: + OperatorEXXPW(const int* isk_in, + const ModulePW::PW_Basis_K* wfcpw_in, + const ModulePW::PW_Basis* rhopw_in, + K_Vectors* kv_in, + const UnitCell* ucell); + + template + explicit OperatorEXXPW(const OperatorEXXPW *op_exx); + + virtual ~OperatorEXXPW(); + + virtual void act(const int nbands, + const int nbasis, + const int npol, + const T *tmpsi_in, + T *tmhpsi, + const int ngk_ik = 0, + const bool is_first_node = false) const override; + + double cal_exx_energy(psi::Psi *psi_) const; + + void set_psi(psi::Psi &psi_in) const { psi = psi_in; } + + void set_wg(const ModuleBase::matrix *wg_in) { wg = wg_in; } + + void construct_ace() const; + + bool first_iter = false; + + private: + const int *isk = nullptr; + const ModulePW::PW_Basis_K *wfcpw = nullptr; + const ModulePW::PW_Basis *rhopw = nullptr; + const UnitCell *ucell = nullptr; + Real exx_div = 0; + Real tpiba = 0; + + std::vector get_q_points(const int ik) const; + const T *get_pw(const int m, const int iq) const; + + void multiply_potential(T *density_recip, int ik, int iq) const; + + void exx_divergence(); + + void get_potential() const; + + void act_op(const int nbands, + const int nbasis, + const int npol, + const T *tmpsi_in, + T *tmhpsi, + const int ngk_ik = 0, + const bool is_first_node = false) const; + + void act_op_ace(const int nbands, + const int nbasis, + const int npol, + const T *tmpsi_in, + T *tmhpsi, + const int ngk_ik = 0, + const bool is_first_node = false) const; + + double cal_exx_energy_op(psi::Psi *psi_) const; + + double cal_exx_energy_ace(psi::Psi *psi_) const; + + mutable int cnt = 0; + + mutable bool potential_got = false; + + // pws +// mutable std::vector> pws; + + // k vectors + K_Vectors *kv = nullptr; + + // psi + mutable 
psi::Psi psi; + const ModuleBase::matrix* wg; + + // real space memory + T *psi_nk_real = nullptr; + T *psi_mq_real = nullptr; + T *density_real = nullptr; + T *h_psi_real = nullptr; + // density recip space memory + T *density_recip = nullptr; + // h_psi recip space memory + T *h_psi_recip = nullptr; + Real *pot = nullptr; + + // Lin Lin's ACE memory, 10.1021/acs.jctc.6b00092 + mutable T* h_psi_ace = nullptr; // H \Psi, W in the paper + mutable T* psi_h_psi_ace = nullptr; // \Psi^{\dagger} H \Psi, M in the paper + mutable T* L_ace = nullptr; // cholesky(-M).L, L in the paper + mutable std::vector Xi_ace_k; // L^{-1} (H \Psi)^{\dagger}, \Xi in the paper +// mutable T* Xi_ace = nullptr; // L^{-1} (H \Psi)^{\dagger}, \Xi in the paper + + mutable std::map> q_points; + + // occupational number + const ModuleBase::matrix *p_wg; + +// mutable bool update_psi = false; + + Device *ctx = {}; + base_device::DEVICE_CPU* cpu_ctx = {}; + base_device::AbacusDevice_t device = {}; + + using setmem_complex_op = base_device::memory::set_memory_op; + using resmem_complex_op = base_device::memory::resize_memory_op; + using delmem_complex_op = base_device::memory::delete_memory_op; + using syncmem_complex_op = base_device::memory::synchronize_memory_op; + using resmem_real_op = base_device::memory::resize_memory_op; + using delmem_real_op = base_device::memory::delete_memory_op; + using gemm_complex_op = ModuleBase::gemm_op; + using vec_add_vec_complex_op = ModuleBase::constantvector_addORsub_constantVector_op; + using dot_op = ModuleBase::dot_real_op; +}; + +} // namespace hamilt + +#endif // OPEXXPW_H \ No newline at end of file diff --git a/source/module_hamilt_pw/hamilt_pwdft/stress_func_exx.cpp b/source/module_hamilt_pw/hamilt_pwdft/stress_func_exx.cpp new file mode 100644 index 0000000000..6515a1e0df --- /dev/null +++ b/source/module_hamilt_pw/hamilt_pwdft/stress_func_exx.cpp @@ -0,0 +1,261 @@ +#include "stress_pw.h" +#include "global.h" + +template +void 
Stress_PW::stress_exx(ModuleBase::matrix& sigma, + const ModuleBase::matrix& wg, + ModulePW::PW_Basis* rhopw, + ModulePW::PW_Basis_K* wfcpw, + const K_Vectors *p_kv, + const psi::Psi, Device>* d_psi_in, const UnitCell& ucell) +{ + // T is complex of FPTYPE, if FPTYPE is double, T is std::complex + // but if FPTYPE is std::complex, T is still std::complex + using T = std::complex; + using Real = FPTYPE; + using setmem_complex_op = base_device::memory::set_memory_op; + using resmem_complex_op = base_device::memory::resize_memory_op; + using resmem_real_op = base_device::memory::resize_memory_op; + using delmem_complex_op = base_device::memory::delete_memory_op; + using delmem_real_op = base_device::memory::delete_memory_op; + using syncmem_complex_op = base_device::memory::synchronize_memory_op; + + int nks = wfcpw->nks; + int nqs = wfcpw->nks; // currently q-points downsampling is not supported + double omega = ucell.omega; + double tpiba = ucell.tpiba; + double tpiba2 = ucell.tpiba2; + double omega_inv = 1.0 / omega; + + // allocate space + T* psi_nk_real = nullptr; + T* psi_mq_real = nullptr; + T* density_real = nullptr; + T* density_recip = nullptr; + Real* pot = nullptr; // This factor is 2x of the potential in 10.1103/PhysRevB.73.125120 + + resmem_complex_op()(psi_nk_real, wfcpw->nrxx); + resmem_complex_op()(psi_mq_real, wfcpw->nrxx); + resmem_complex_op()(density_real, rhopw->nrxx); + resmem_complex_op()(density_recip, rhopw->npw); + resmem_real_op()(pot, rhopw->npw * nks * nks); + + // prepare the coefficients + double exx_div = 0; + + // pasted from op_exx_pw.cpp + { + if (GlobalC::exx_info.info_lip.lambda == 0.0) + { + return; + } + + // here we follow the exx_divergence subroutine in q-e (PW/src/exx_base.f90) + double alpha = 10.0 / wfcpw->gk_ecut; + double div = 0; + + // this is the \sum_q F(q) part + // temporarily for all k points, should be replaced to q points later + for (int ik = 0; ik < wfcpw->nks; ik++) + { + auto k = wfcpw->kvec_c[ik]; +#ifdef 
_OPENMP +#pragma omp parallel for reduction(+:div) +#endif + for (int ig = 0; ig < rhopw->npw; ig++) + { + auto q = k + rhopw->gcar[ig]; + double qq = q.norm2(); + if (qq <= 1e-8) continue; + else if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc) + { + double hse_omega = GlobalC::exx_info.info_global.hse_omega; + double omega2 = hse_omega * hse_omega; + div += std::exp(-alpha * qq) / qq * (1.0 - std::exp(-qq*tpiba2 / 4.0 / omega2)); + } + else + { + div += std::exp(-alpha * qq) / qq; + } + } + } + + Parallel_Reduce::reduce_pool(div); + // std::cout << "EXX div: " << div << std::endl; + + // if (PARAM.inp.dft_functional == "hse") + if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc) + { + double hse_omega = GlobalC::exx_info.info_global.hse_omega; + div += tpiba2 / 4.0 / hse_omega / hse_omega; // compensate for the finite value when qq = 0 + } + else + { + div -= alpha; + } + + div *= ModuleBase::e2 * ModuleBase::FOUR_PI / tpiba2 / wfcpw->nks; + + // numerically value the mean value of F(q) in the reciprocal space + // This means we need to calculate the average of F(q) in the first brillouin zone + alpha /= tpiba2; + int nqq = 100000; + double dq = 5.0 / std::sqrt(alpha) / nqq; + double aa = 0.0; + // if (PARAM.inp.dft_functional == "hse") + if (GlobalC::exx_info.info_global.ccp_type == Conv_Coulomb_Pot_K::Ccp_Type::Erfc) + { + double hse_omega = GlobalC::exx_info.info_global.hse_omega; + double omega2 = hse_omega * hse_omega; +#ifdef _OPENMP +#pragma omp parallel for reduction(+:aa) +#endif + for (int i = 0; i < nqq; i++) + { + double q = dq * (i+0.5); + aa -= exp(-alpha * q * q) * exp(-q*q / 4.0 / omega2) * dq; + } + } + aa *= 8 / ModuleBase::FOUR_PI; + aa += 1.0 / std::sqrt(alpha * ModuleBase::PI); + div -= ModuleBase::e2 * omega * aa; + exx_div = div * wfcpw->nks; + // std::cout << "EXX divergence: " << exx_div << std::endl; + + } + + // prepare for the potential + for (int ik = 0; ik < nks; ik++) + 
{ + for (int iq = 0; iq < nqs; iq++) + { + auto k = wfcpw->kvec_c[ik]; + auto q = wfcpw->kvec_c[iq]; + #ifdef _OPENMP + #pragma omp parallel for schedule(static) + #endif + for (int ig = 0; ig < rhopw->npw; ig++) + { + FPTYPE qq = (k - q + rhopw->gcar[ig]).norm2() * tpiba2; + FPTYPE fac = -ModuleBase::FOUR_PI * ModuleBase::e2 / qq; + if (qq < 1e-8) + { + pot[ig + iq * rhopw->npw + ik * rhopw->npw * nqs] = exx_div; + } + else + { + pot[ig + iq * rhopw->npw + ik * rhopw->npw * nqs] = fac; + } + } + } + } + + // calculate the stress + + // for nk, mq + for (int ik = 0; ik < nks; ik++) + { + for (int nband = 0; nband < d_psi_in->get_nbands(); nband++) + { + if (wg(ik, nband) < 1e-12) continue; + // psi_nk in real space + d_psi_in->fix_kb(ik, nband); + T* psi_nk = d_psi_in->get_pointer(); + wfcpw->recip2real(psi_nk, psi_nk_real, ik); + + for (int iq = 0; iq < nqs; iq++) + { + for (int mband = 0; mband < d_psi_in->get_nbands(); mband++) + { + // psi_mq in real space + d_psi_in->fix_kb(iq, mband); + T* psi_mq = d_psi_in->get_pointer(); + wfcpw->recip2real(psi_mq, psi_mq_real, iq); + + // overlap density in real space + setmem_complex_op()(density_real, 0.0, rhopw->nrxx); + for (int ig = 0; ig < rhopw->nrxx; ig++) + { + density_real[ig] = psi_nk_real[ig] * std::conj(psi_mq_real[ig]) * omega_inv; + } + + // density in reciprocal space + rhopw->real2recip(density_real, density_recip); + + // really calculate the stress + + // for alpha beta + for (int alpha = 0; alpha < 3; alpha++) + { + for (int beta = alpha; beta < 3; beta++) + { + int delta_ab = (alpha == beta) ? 
1 : 0; + double sigma_ab_loc = 0.0; + #ifdef _OPENMP + #pragma omp parallel for schedule(static) reduction(+:sigma_ab_loc) + #endif + for (int ig = 0; ig < rhopw->npw; ig++) + { + auto kqg = wfcpw->kvec_c[ik] - wfcpw->kvec_c[iq] + rhopw->gcar[ig]; + double kqg_alpha = kqg[alpha] * tpiba; + double kqg_beta = kqg[beta] * tpiba; + // equation 10 of 10.1103/PhysRevB.73.125120 + double density_recip2 = std::real(density_recip[ig] * std::conj(density_recip[ig])); + double pot_local = pot[ig + iq * rhopw->npw + ik * rhopw->npw * nqs]; + double _4pi_e2 = ModuleBase::FOUR_PI * ModuleBase::e2; + sigma_ab_loc += density_recip2 * pot_local * (kqg_alpha * kqg_beta * (-pot_local) / _4pi_e2 - delta_ab) ; +// if (std::abs(pot_local + 22.235163511253440) < 1e-2) +// { +// std::cout << "delta_ab: " << delta_ab << std::endl; +// std::cout << "density_recip2: " << density_recip2 << std::endl; +// std::cout << "pot_local: " << pot_local << std::endl; +// std::cout << "kqg_alpha: " << kqg_alpha << std::endl; +// std::cout << "kqg_beta: " << kqg_beta << std::endl; +// +// } + } + + // 0.5 in the following line is caused by 2x in the pot + sigma(alpha, beta) -= GlobalC::exx_info.info_global.hybrid_alpha + * 0.25 * sigma_ab_loc + * wg(ik, nband) * wg(iq, mband) / nqs / p_kv->wk[ik]; + } + } + } + } + } + } + + for (int l = 0; l < 3; l++) + { + for (int m = l + 1; m < 3; m++) + { + sigma(m, l) = sigma(l, m); + } + } + + Parallel_Reduce::reduce_all(sigma.c, sigma.nr * sigma.nc); + +//// print sigma +// for (int i = 0; i < 3; i++) +// { +// for (int j = 0; j < 3; j++) +// { +// std::cout << sigma(i, j) * ModuleBase::RYDBERG_SI / pow(ModuleBase::BOHR_RADIUS_SI, 3) * 1.0e-8 << " "; +// sigma(i, j) = 0; +// } +// std::cout << std::endl; +// } + + + delmem_complex_op()(psi_nk_real); + delmem_complex_op()(psi_mq_real); + delmem_complex_op()(density_real); + delmem_complex_op()(density_recip); + delmem_real_op()(pot); +} + +template class Stress_PW; +#if ((defined __CUDA) || (defined __ROCM)) 
+template class Stress_PW; +#endif diff --git a/source/module_hamilt_pw/hamilt_pwdft/stress_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/stress_pw.cpp index f0ca34ea3a..827eed53ae 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/stress_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/stress_pw.cpp @@ -49,6 +49,9 @@ void Stress_PW::cal_stress(ModuleBase::matrix& sigmatot, // DFT+U and DeltaSpin stress ModuleBase::matrix sigmaonsite; sigmaonsite.create(3, 3); + // EXX PW stress + ModuleBase::matrix sigmaexx; + sigmaexx.create(3, 3); for (int i = 0; i < 3; i++) { @@ -64,6 +67,7 @@ void Stress_PW::cal_stress(ModuleBase::matrix& sigmatot, sigmaxcc(i, j) = 0.0; sigmavdw(i, j) = 0.0; sigmaonsite(i, j) = 0.0; + sigmaexx(i, j) = 0.0; } } @@ -118,13 +122,21 @@ void Stress_PW::cal_stress(ModuleBase::matrix& sigmatot, this->stress_onsite(sigmaonsite, this->pelec->wg, wfc_basis, ucell, d_psi_in, p_symm); } + // EXX PW stress + if (GlobalC::exx_info.info_global.cal_exx) + { + this->stress_exx(sigmaexx, this->pelec->wg, rho_basis, wfc_basis, p_kv, d_psi_in, ucell); + } + + for (int ipol = 0; ipol < 3; ipol++) { for (int jpol = 0; jpol < 3; jpol++) { sigmatot(ipol, jpol) = sigmakin(ipol, jpol) + sigmahar(ipol, jpol) + sigmanl(ipol, jpol) + sigmaxc(ipol, jpol) + sigmaxcc(ipol, jpol) + sigmaewa(ipol, jpol) - + sigmaloc(ipol, jpol) + sigmavdw(ipol, jpol) + sigmaonsite(ipol, jpol); + + sigmaloc(ipol, jpol) + sigmavdw(ipol, jpol) + sigmaonsite(ipol, jpol) + + sigmaexx(ipol, jpol); } } @@ -136,9 +148,9 @@ void Stress_PW::cal_stress(ModuleBase::matrix& sigmatot, bool ry = false; ModuleIO::print_stress("TOTAL-STRESS", sigmatot, true, ry); - if (PARAM.inp.test_stress) + if (PARAM.inp.test_stress || true) { - ry = true; +// ry = true; GlobalV::ofs_running << "\n PARTS OF STRESS: " << std::endl; GlobalV::ofs_running << std::setiosflags(std::ios::showpos); GlobalV::ofs_running << std::setiosflags(std::ios::fixed) << std::setprecision(10) << std::endl; @@ -153,6 +165,10 @@ void 
Stress_PW::cal_stress(ModuleBase::matrix& sigmatot, { ModuleIO::print_stress("ONSITE STRESS", sigmaonsite, PARAM.inp.test_stress, ry); } + if (GlobalC::exx_info.info_global.cal_exx) + { + ModuleIO::print_stress("EXX STRESS", sigmaexx, PARAM.inp.test_stress, ry); + } ModuleIO::print_stress("TOTAL STRESS", sigmatot, PARAM.inp.test_stress, ry); } ModuleBase::timer::tick("Stress_PW", "cal_stress"); diff --git a/source/module_hamilt_pw/hamilt_pwdft/stress_pw.h b/source/module_hamilt_pw/hamilt_pwdft/stress_pw.h index eb08d8ce14..5ac9a2abcd 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/stress_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/stress_pw.h @@ -35,6 +35,16 @@ class Stress_PW : public Stress_Func const pseudopot_cell_vnl& nlpp, const UnitCell& ucell); // nonlocal part of uspp in PW basis + // exx stress due to the scaling of the lattice vectors + // see 10.1103/PhysRevB.73.125120 for details + void stress_exx(ModuleBase::matrix& sigma, + const ModuleBase::matrix& wg, + ModulePW::PW_Basis* rho_basis, + ModulePW::PW_Basis_K* wfc_basis, + const K_Vectors* p_kv, + const psi::Psi, Device>* d_psi_in, + const UnitCell& ucell); // exx stress in PW basis + const elecstate::ElecState* pelec = nullptr; }; #endif diff --git a/source/module_io/input_conv.cpp b/source/module_io/input_conv.cpp index a2240f8b8a..19c182d462 100644 --- a/source/module_io/input_conv.cpp +++ b/source/module_io/input_conv.cpp @@ -316,9 +316,6 @@ void Input_Conv::Convert() //---------------------------------------------------------- // about exx, Peize Lin add 2018-06-20 //---------------------------------------------------------- -#ifdef __EXX -#ifdef __LCAO - std::string dft_functional_lower = PARAM.inp.dft_functional; std::transform(PARAM.inp.dft_functional.begin(), PARAM.inp.dft_functional.end(), @@ -329,14 +326,20 @@ void Input_Conv::Convert() GlobalC::exx_info.info_global.cal_exx = true; GlobalC::exx_info.info_global.ccp_type = Conv_Coulomb_Pot_K::Ccp_Type::Hf; - } else if 
(dft_functional_lower == "hse") { + } + else if (dft_functional_lower == "hse") + { GlobalC::exx_info.info_global.cal_exx = true; GlobalC::exx_info.info_global.ccp_type = Conv_Coulomb_Pot_K::Ccp_Type::Erfc; - } else if (dft_functional_lower == "opt_orb") { + } +#ifdef __EXX + else if (dft_functional_lower == "opt_orb") + { GlobalC::exx_info.info_global.cal_exx = false; Exx_Abfs::Jle::generate_matrix = true; } +#endif // muller, power, wp22, cwp22 added by jghan, 2024-07-07 else if ( dft_functional_lower == "muller" || dft_functional_lower == "power" ) { @@ -353,11 +356,22 @@ void Input_Conv::Convert() GlobalC::exx_info.info_global.cal_exx = true; GlobalC::exx_info.info_global.ccp_type = Conv_Coulomb_Pot_K::Ccp_Type::Erfc; // use the erfc(w|r-r'|), exx just has the short-range part } + else if (dft_functional_lower == "b3lyp") + { + GlobalC::exx_info.info_global.cal_exx = true; + GlobalC::exx_info.info_global.ccp_type + = Conv_Coulomb_Pot_K::Ccp_Type::Hf; + } else { GlobalC::exx_info.info_global.cal_exx = false; } - if (GlobalC::exx_info.info_global.cal_exx || Exx_Abfs::Jle::generate_matrix || PARAM.inp.rpa) + if (GlobalC::exx_info.info_global.cal_exx +#ifdef __EXX + || Exx_Abfs::Jle::generate_matrix + || PARAM.inp.rpa +#endif + ) { // EXX case, convert all EXX related variables // GlobalC::exx_info.info_global.cal_exx = true; @@ -384,18 +398,37 @@ void Input_Conv::Convert() GlobalC::exx_info.info_ri.cauchy_stress_threshold = PARAM.inp.exx_cauchy_stress_threshold; GlobalC::exx_info.info_ri.ccp_rmesh_times = std::stod(PARAM.inp.exx_ccp_rmesh_times); +#ifdef __EXX Exx_Abfs::Jle::Lmax = PARAM.inp.exx_opt_orb_lmax; Exx_Abfs::Jle::Ecut_exx = PARAM.inp.exx_opt_orb_ecut; Exx_Abfs::Jle::tolerence = PARAM.inp.exx_opt_orb_tolerence; +#endif // EXX does not support symmetry for nspin==4 - if (PARAM.inp.calculation != "nscf" && PARAM.inp.symmetry == "1" && PARAM.inp.nspin == 4) + if (PARAM.inp.calculation != "nscf" && PARAM.inp.symmetry == "1" && PARAM.inp.nspin == 4 && 
PARAM.inp.basis_type == "lcao") { ModuleSymmetry::Symmetry::symm_flag = -1; } } -#endif // __LCAO -#endif // __EXX + + if (GlobalC::exx_info.info_global.cal_exx && PARAM.inp.basis_type == "pw") + { + if (ModuleSymmetry::Symmetry::symm_flag != -1) + { + ModuleBase::WARNING("Input_Conv", "EXX PW works only with symmetry=-1"); + ModuleSymmetry::Symmetry::symm_flag = -1; + } + + if (PARAM.inp.nspin != 1) + { + ModuleBase::WARNING_QUIT("Input_Conv", "EXX PW works only with nspin=1"); + } + + if (PARAM.inp.device != "cpu") + { + ModuleBase::WARNING_QUIT("Input_Conv", "EXX PW works only with device=cpu"); + } + } //---------------------------------------------------------- // reset symmetry flag to avoid error diff --git a/source/module_io/read_input_item_exx_dftu.cpp b/source/module_io/read_input_item_exx_dftu.cpp index b74611bd79..a313e08be7 100644 --- a/source/module_io/read_input_item_exx_dftu.cpp +++ b/source/module_io/read_input_item_exx_dftu.cpp @@ -32,6 +32,10 @@ void ReadInput::item_exx() { para.input.exx_hybrid_alpha = "1"; } + else if (dft_functional_lower == "b3lyp") + { + para.input.exx_hybrid_alpha = "0.2"; + } else { // no exx in scf, but will change to non-zero in // postprocess like rpa diff --git a/source/module_io/read_input_item_other.cpp b/source/module_io/read_input_item_other.cpp index b23d9c60d2..29e5f1f365 100644 --- a/source/module_io/read_input_item_other.cpp +++ b/source/module_io/read_input_item_other.cpp @@ -526,5 +526,13 @@ void ReadInput::item_others() this->add_item(item); } + // EXX PW by rhx0820, 2025-03-10 + { + Input_Item item("exxace"); + item.annotation = "whether to perform ace calculation in exxpw, default is false"; + read_sync_bool(input.exxace); + this->add_item(item); + } + } } // namespace ModuleIO \ No newline at end of file diff --git a/source/module_parameter/input_parameter.h b/source/module_parameter/input_parameter.h index 4e42017ebd..4e0e8537e4 100644 --- a/source/module_parameter/input_parameter.h +++ 
b/source/module_parameter/input_parameter.h @@ -637,5 +637,9 @@ struct Input_para double rdmft_power_alpha = 0.656; // the alpha parameter of power-functional, g(occ_number) = occ_number^alpha // double rdmft_wp22_omega; // the omega parameter of wp22-functional = exx_hse_omega + // ============== #Parameters (22.EXX PW) ===================== + // EXX for planewave basis, rhx0820 2025-03-10 + bool exxace = true; // exxace, exact exchange for planewave basis + }; #endif diff --git a/source/module_psi/psi.cpp b/source/module_psi/psi.cpp index 07452554b1..6fb18f3ee9 100644 --- a/source/module_psi/psi.cpp +++ b/source/module_psi/psi.cpp @@ -65,7 +65,7 @@ Psi::Psi(const int nk_in, this->nk = nk_in; this->nbands = nbd_in; this->nbasis = nbs_in; - + this->current_b = 0; this->current_k = 0; this->current_nbasis = nbs_in; @@ -226,6 +226,30 @@ void Psi::set_all_psi(const T* another_pointer, const std::size_t siz synchronize_memory_op()(this->psi, another_pointer, this->size()); } +template +Psi& Psi::operator=(const Psi& psi_in) +{ +// printf("%d\n", &psi_in); + this->ngk = psi_in.ngk; + this->nk = psi_in.get_nk(); + this->nbands = psi_in.get_nbands(); + this->nbasis = psi_in.get_nbasis(); + this->current_k = psi_in.get_current_k(); + this->current_b = psi_in.get_current_b(); + this->k_first = psi_in.get_k_first(); + // this function will copy psi_in.psi to this->psi no matter the device types of each other. 
+ + this->resize(psi_in.get_nk(), psi_in.get_nbands(), psi_in.get_nbasis()); + base_device::memory::synchronize_memory_op()(this->psi, + psi_in.psi, + psi_in.size()); + this->psi_bias = psi_in.get_psi_bias(); + this->current_nbasis = psi_in.get_current_nbas(); + this->psi_current = this->psi + psi_in.get_psi_bias(); + + return *this; +} + template void Psi::resize(const int nks_in, const int nbands_in, const int nbasis_in) { @@ -296,8 +320,8 @@ const int& Psi::get_current_ngk() const } template -int Psi::get_npol() const -{ +int Psi::get_npol() const +{ if (PARAM.inp.nspin == 4) { return 2; diff --git a/source/module_psi/psi.h b/source/module_psi/psi.h index 9b427c070d..6ba0a33720 100644 --- a/source/module_psi/psi.h +++ b/source/module_psi/psi.h @@ -74,6 +74,9 @@ class Psi // size_t size() const {return this->psi.size();} size_t size() const; + // copy assignment operator + Psi& operator=(const Psi& psi_in); + // allocate psi for three dimensions void resize(const int nks_in, const int nbands_in, const int nbasis_in);