Skip to content

Commit 7fe3a1d

Browse files
committed
add EXX operators to Grad
1 parent 8355d80 commit 7fe3a1d

File tree

21 files changed

+426
-291
lines changed

21 files changed

+426
-291
lines changed

source/Makefile.Objects

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ ${OBJS_DELTASPIN}\
117117
${OBJS_TENSOR}\
118118
${OBJS_HSOLVER_PEXSI}\
119119
${OBJS_LR}\
120+
${OBJS_LR_GRAD}\
120121
${OBJS_RDMFT}
121122

122123
OBJS_MAIN=main.o\
@@ -750,12 +751,13 @@ OBJS_TENSOR=tensor.o\
750751
cpu_allocator.o\
751752
refcount.o
752753

753-
OBJS_LR=lr_util.o\
754+
OBJS_LR=lr_util.o\
754755
lr_util_hcontainer.o\
755756
ao_to_mo_parallel.o\
756757
ao_to_mo_serial.o\
757758
dm_trans_parallel.o\
758759
dm_trans_serial.o\
760+
dm_band.o\
759761
dmr_complex.o\
760762
operator_lr_hxc.o\
761763
operator_lr_exx.o\
@@ -766,6 +768,13 @@ OBJS_TENSOR=tensor.o\
766768
hamilt_casida.o\
767769
esolver_lrtd_lcao.o\
768770

771+
OBJS_LR_GRAD=lr_force.o\
772+
CVCX_serial.o\
773+
CVCX_parallel.o\
774+
cal_edm_from_multipliers.o\
775+
pot_grad_xc.o\
776+
esolver_lr_grad.o\
777+
769778
OBJS_RDMFT=rdmft.o\
770779
rdmft_tools.o\
771780
rdmft_pot.o\

source/module_hamilt_general/operator.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,11 @@ enum class calculation_type
2727
lcao_dftu,
2828
lcao_sc_lambda,
2929
lcao_tddft_velocity,
30-
lr_dmtrans_vo,
31-
lr_dmdiff_vo
30+
lr_dmtrans_hxc,
31+
lr_dmtrans_gxc,
32+
lr_dmdiff_hxc,
33+
lr_dmtrans_exx,
34+
lr_dmdiff_exx
3235
};
3336

3437
// Basic class for operator module,

source/module_hamilt_lcao/hamilt_lcaodft/operator_lcao/operator_lcao.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,11 @@ void OperatorLCAO<TK, TR>::init(const int ik_in) {
187187
break;
188188
}
189189
case calculation_type::lcao_exx:
190+
case calculation_type::lr_dmtrans_hxc:
191+
case calculation_type::lr_dmtrans_exx:
192+
case calculation_type::lr_dmdiff_hxc:
193+
case calculation_type::lr_dmdiff_exx:
194+
case calculation_type::lr_dmtrans_gxc:
190195
{
191196
//update HR first
192197
if (!this->hr_done)

source/module_lr/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ if(ENABLE_LCAO)
1212
dm_trans/dm_trans_parallel.cpp
1313
dm_trans/dm_trans_serial.cpp
1414
dm_trans/dmr_complex.cpp
15+
dm_band/dm_band.cpp
1516
operator_casida/operator_lr_hxc.cpp
1617
operator_casida/operator_lr_exx.cpp
1718
potentials/pot_hxc_lrtd.cpp

source/module_lr/Grad/esolver_lr_grad.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void LR::ESolver_LR<T, TR>::init_pot_groundstate(const Charge& chg_gs)
8484
this->pot_gs.get()->pot_register(pot_register);
8585
XC_Functional::set_xc_type(ucell.atoms[0].ncpp.xc_func); // set XC type of the ground state
8686
this->pot_gs->init_pot(0, &chg_gs); // call update_from_charge inside
87-
if (has_local_xc(this->xc_kernel))
87+
if (LR_Util::has_local_xc(this->xc_kernel))
8888
{
8989
XC_Functional::set_xc_type(this->xc_kernel); // recover the excited state xc kernel type
9090
}
@@ -179,7 +179,7 @@ std::vector<ModuleBase::matrix> LR::ESolver_LR<T, TR>::cal_force(const int ispin
179179
#endif
180180
this->gint_, pot_weak, pot_hxc_gs_weak,
181181
this->kv, this->gd, this->paraX_, this->paraC_, this->paraMat_,
182-
has_local_xc(this->xc_kernel));
182+
this->xc_kernel);
183183
if (PARAM.inp.test_force && nocc[0] == 1)
184184
{
185185
const std::vector<ct::Tensor>& dm_diff = cal_dm_diff_pblas(this->X[0].template data<T>() + offset, this->paraX_[0], c, this->paraC_, this->nbasis, this->nocc[0], this->nvirt[0], this->paraMat_);

source/module_lr/Grad/multipliers/cal_edm_from_multipliers.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include "module_lr/utils/lr_util_print.h"
55
#include "cal_multiplier_w_from_z.h"
66
#include <ATen/ops/linalg_op.h>
7+
#ifdef __EXX
8+
#include "module_lr/operator_casida/operator_lr_exx.h"
9+
#endif
710
namespace LR
811
{
912

@@ -152,7 +155,7 @@ namespace LR
152155
const std::vector<Parallel_2D>& px,
153156
const Parallel_2D& pc,
154157
const Parallel_Orbitals& pmat,
155-
const bool has_local_xc)
158+
const std::string xc_kernel)
156159
{
157160
const int nk = kv.get_nks() / nspin;
158161
// 1. calculate W multiplier
@@ -164,7 +167,7 @@ namespace LR
164167
#ifdef __EXX
165168
exx_lri, exx_alpha,
166169
#endif
167-
gint, pot_hxc_gs, kv, px, pc, p_occ_occ, pmat, has_local_xc);
170+
gint, pot_hxc_gs, kv, px, pc, p_occ_occ, pmat, xc_kernel);
168171
std::cout << "W: " << std::endl;
169172
LR_Util::print_value(W.data(), nk, p_occ_occ[0].get_col_size(), p_occ_occ[0].get_row_size());
170173

@@ -174,11 +177,15 @@ namespace LR
174177
relaxed_diff_dm, gint, pot, ucell, orb_cutoff, gd, kv, px, pc, pmat,
175178
{ 0 }, T(2.0), OperatorLRHxc<T>::MO_TO_AO_TYPE::CXC_o);
176179
#ifdef __EXX
177-
// add EXX operators here
180+
OperatorLREXX<T> op_K_exx(nspin, naos, nocc[0], nvirt[0], ucell, c,
181+
relaxed_diff_dm, exx_lri, kv, px[0], pc, pmat,
182+
2.0 * exx_alpha, OperatorLREXX<T>::MO_TO_AO_TYPE::CXC_o);
178183
#endif
179184
const int ld_vo = nk * px[0].get_local_size();
180185
std::vector<T> K_cvcx(ld_vo, 0.0);
181186
op_K_cvcx.act(/*nbands=*/1, ld_vo, /*npol=*/1, X, K_cvcx.data());
187+
if (LR::exx_kernel_list().count(xc_kernel))
188+
op_K_exx.act(/*nbands=*/1, ld_vo, /*npol=*/1, X, K_cvcx.data());
182189

183190
return cal_edm_terms_from_XZWK(X, Z, W.data(), K_cvcx.data(), eig_ext_istate, eig_ks, c, nspin, p_occ_occ[0], px[0], pc, pmat);
184191
}

source/module_lr/Grad/multipliers/cal_multiplier_w_from_z.h

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include "module_basis/module_ao/parallel_orbitals.h"
88
#include <ATen/ops/linalg_op.h>
99
#ifdef __EXX
10-
#include "module_ri/Exx_LRI.h"
10+
#include "module_lr/operator_casida/operator_lr_exx.h"
1111
#endif
1212
#include "module_base/scalapack_connector.h"
1313

@@ -78,11 +78,14 @@ namespace LR
7878
const Parallel_2D& pc,
7979
const std::vector<Parallel_2D>& p_occ_occ, // < for W
8080
const Parallel_Orbitals& pmat,
81-
const bool has_local_xc,
81+
const std::string xc_kernel,
8282
const std::string& spin_type = "singlet")
8383
{
8484
ModuleBase::TITLE("cal_W_from_Z", "cal_W_from_Z");
8585
using ATYPE = typename OperatorLRHxc<T>::MO_TO_AO_TYPE;
86+
#ifdef __EXX
87+
using ATYPE_EXX = typename OperatorLREXX<T>::MO_TO_AO_TYPE;
88+
#endif
8689
const int nk = kv.get_nks() / nspin;
8790
// allocate memory for DMs
8891
elecstate::DensityMatrix<T, T> DM_trans(&pmat, 1, kv.kvec_d, nk); //DX
@@ -94,6 +97,11 @@ namespace LR
9497
OperatorLRHxc<T> op_ht(nspin, naos, nocc, nvirt, psi_ks,
9598
DM_diff_relaxed, gint, pot_hxc_gs, ucell, orb_cutoff, gd, kv, p_occ_occ, pc, pmat,
9699
{ 0 }, T(1.0), ATYPE::CC_oo);
100+
#ifdef __EXX
101+
OperatorLREXX<T> op_ht_exx(nspin, naos, nocc[0], nvirt[0], ucell, psi_ks,
102+
DM_diff_relaxed, exx_lri, kv, p_occ_occ[0], pc, pmat,
103+
exx_alpha, ATYPE_EXX::CC_oo);
104+
#endif
97105
// 2. $2\sum_{jb,kc} g^{xc}_{ia, jb, kc}X_{jb}X_{kc}$
98106
// use pointer here for polymorphism
99107
// but `weak_ptr = make_shared()` will cause a segment fault because the shared_ptr is a temporary object
@@ -105,9 +113,7 @@ namespace LR
105113
OperatorLRHxc<T> op_gxc(nspin, naos, nocc, nvirt, psi_ks,
106114
DM_trans, gint, pot_grad, ucell, orb_cutoff, gd, kv, p_occ_occ, pc, pmat,
107115
{ 0 }, T(-2.0), ATYPE::CC_oo);
108-
#ifdef __EXX
109-
// add EXX operators here
110-
#endif
116+
111117
auto cal_dm_trans = [&](const int is, const T* const x_ptr)->void //DX
112118
{
113119
const auto psi_ks_is = LR_Util::get_psi_spin(psi_ks, is, nk);
@@ -151,7 +157,12 @@ namespace LR
151157
cal_dm_diff_relaxed(0, X, Z); // relaxed difference density matrix T+DZ
152158
// the 3 terms
153159
op_ht.act(/*nband=*/1, ld_oo, /*npol=*/1, X, W); //comment out this line to test H[T+Z]=0
154-
if (has_local_xc) { op_gxc.act(/*nband=*/1, ld_oo, /*npol=*/1, X, W); }
160+
if (LR::exx_kernel_list().count(xc_kernel))
161+
op_ht_exx.act(/*nband=*/1, ld_oo, /*npol=*/1, X, W);
162+
163+
if (LR_Util::has_local_xc(xc_kernel))
164+
op_gxc.act(/*nband=*/1, ld_oo, /*npol=*/1, X, W);
165+
155166
std::cout << "W (H[T+Z]) + W(gxc) terms: " << std::endl;
156167
LR_Util::print_value(W, nk, p_occ_occ[0].get_col_size(), p_occ_occ[0].get_row_size());
157168
add_ediff_term(W, X, eig, eig_ks, nk, nocc[0], nvirt[0], px[0], p_occ_occ[0]);

source/module_lr/Grad/multipliers/hamilt_zeq_left.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@
55
#include "module_lr/operator_casida/operator_lr_hxc.h"
66
#include "module_lr/Grad/dm_diff/dm_diff.h"
77
#include "module_basis/module_ao/parallel_orbitals.h"
8+
#ifdef __EXX
9+
#include "module_lr/operator_casida/operator_lr_exx.h"
10+
#endif
811
namespace LR
912
{
1013
template<typename T>
1114
class Z_vector_L : public HamiltLR<T>
1215
{
1316
using ATYPE = typename OperatorLRHxc<T>::MO_TO_AO_TYPE;
17+
#ifdef __EXX
18+
using ATYPE_EXX = typename OperatorLREXX<T>::MO_TO_AO_TYPE;
19+
#endif
1420
public:
1521
template<typename TGint>
1622
Z_vector_L(const std::string& xc_kernel,
@@ -51,7 +57,14 @@ namespace LR
5157
{ 0 }, 2.0, ATYPE::CC_vo);
5258
this->ops->add(op_hz);
5359
#ifdef __EXX
54-
// add EXX operators here
60+
if (exx_kernel_list().count(xc_kernel))
61+
{
62+
hamilt::Operator<T>* op_hz_exx = new OperatorLREXX<T>(nspin, naos, nocc[0], nvirt[0], ucell, psi_ks,
63+
*this->DM_trans, exx_lri, kv, pX[0], pc, pmat,
64+
2.0 * exx_alpha, //alpha; H=2K when D is symmetrized
65+
ATYPE_EXX::CC_vo);
66+
this->ops->add(op_hz_exx);
67+
}
5568
#endif
5669
this->cal_dm_trans = [&, this](const int& is, const T* X)->void
5770
{

source/module_lr/Grad/multipliers/hamilt_zeq_right.h

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,18 @@
55
#include "module_lr/potentials/pot_hxc_lrtd.h"
66
#include "module_lr/operator_casida/operator_lr_hxc.h"
77
#include "module_basis/module_ao/parallel_orbitals.h"
8+
#ifdef __EXX
9+
#include "module_lr/operator_casida/operator_lr_exx.h"
10+
#endif
811
namespace LR
912
{
1013
template<typename T>
1114
class Z_vector_R : public HamiltLR<T>
1215
{
1316
using ATYPE = typename OperatorLRHxc<T>::MO_TO_AO_TYPE;
17+
#ifdef __EXX
18+
using ATYPE_EXX = typename OperatorLREXX<T>::MO_TO_AO_TYPE;
19+
#endif
1420
public:
1521
template<typename TGint>
1622
Z_vector_R(const std::string& xc_kernel,
@@ -42,36 +48,60 @@ namespace LR
4248
gint, pot, kv, pX, pc, pmat, spin_type)
4349
{
4450
ModuleBase::TITLE("Z_vector_R", "Z_vector_R");
51+
4552
this->DM_trans = LR_Util::make_unique<elecstate::DensityMatrix<T, T>>(&pmat, 1, kv.kvec_d, this->nk);
4653
LR_Util::initialize_DMR(*this->DM_trans, pmat, ucell, gd, orb_cutoff);
4754
this->DM_diff = LR_Util::make_unique<elecstate::DensityMatrix<T, T>>(&pmat, 1, kv.kvec_d, this->nk);
4855
LR_Util::initialize_DMR(*this->DM_diff, pmat, ucell, gd, orb_cutoff);
56+
57+
// note: calculation_type cannot repeated, or it will be ignored in ops->add()
4958
// 1. $2\sum_bX_{ib}K_{ab}[D^X]-2\sum_jX_{ja}K_{ij}[D^X]$
5059
this->ops = new OperatorLRHxc<T>(nspin, naos, nocc, nvirt, psi_ks,
5160
*this->DM_trans, gint, pot, ucell, orb_cutoff, gd, kv, pX, pc, pmat,
5261
{ 0 }, -2.0, ATYPE::CXC);
62+
#ifdef __EXX
63+
if (exx_kernel_list().count(xc_kernel))
64+
{
65+
hamilt::Operator<T>* op_hz_exx = new OperatorLREXX<T>(nspin, naos, nocc[0], nvirt[0], ucell, psi_ks,
66+
*this->DM_trans, exx_lri, kv, pX[0], pc, pmat,
67+
-2.0 * exx_alpha, //alpha; H=2K when D is symmetrized
68+
ATYPE_EXX::CXC, {}, hamilt::calculation_type::lr_dmtrans_exx);
69+
this->ops->add(op_hz_exx);
70+
}
71+
#endif
5372
// 2. $H_{ia}[T]$, equals to $2K_{ab}[T]$ when $T$ is symmetrized
5473
hamilt::Operator<T>* op_ht = new OperatorLRHxc<T>(nspin, naos, nocc, nvirt, psi_ks,
5574
*this->DM_diff, gint, pot_hxc_gs, ucell, orb_cutoff, gd, kv, pX, pc, pmat,
56-
{ 0 }, T(-2.0), ATYPE::CC_vo, hamilt::calculation_type::lr_dmdiff_vo);
75+
{ 0 }, T(-2.0), ATYPE::CC_vo, hamilt::calculation_type::lr_dmdiff_hxc);
5776
this->ops->add(op_ht);
77+
#ifdef __EXX
78+
if (exx_kernel_list().count(xc_kernel))
79+
{
80+
hamilt::Operator<T>* op_ht_exx = new OperatorLREXX<T>(nspin, naos, nocc[0], nvirt[0], ucell, psi_ks,
81+
*this->DM_diff, exx_lri, kv, pX[0], pc, pmat,
82+
-2.0 * exx_alpha, //alpha; H=2K when D is symmetrized
83+
ATYPE_EXX::CC_vo, {}, hamilt::calculation_type::lr_dmdiff_exx);
84+
this->ops->add(op_ht_exx);
85+
}
86+
#endif
87+
5888
// 3. $2\sum_{jb,kc} g^{xc}_{ia, jb, kc}X_{jb}X_{kc}$
59-
this->pot_grad = std::make_shared<PotGradXCLR>(pot.lock()->xc_kernel_components, pot.lock()->get_rho_basis(), ucell, pot.lock()->nrxx);
60-
// !!op_gxc has some bug
61-
hamilt::Operator<T>* op_gxc = new OperatorLRHxc<T>(nspin, naos, nocc, nvirt, psi_ks,
62-
*this->DM_trans, gint, this->pot_grad, ucell, orb_cutoff, gd, kv, pX, pc, pmat,
63-
{ 0 }, T(-2.0), ATYPE::CC_vo, hamilt::calculation_type::lr_dmtrans_vo);
64-
assert(op_gxc != nullptr);
65-
std::cout << "op_gxc=" << op_ht << std::endl;
66-
this->ops->add(op_gxc);
89+
if (LR_Util::has_local_xc(xc_kernel))
90+
{ // !! op_gxc has some bug now
91+
this->pot_grad = std::make_shared<PotGradXCLR>(pot.lock()->xc_kernel_components, pot.lock()->get_rho_basis(), ucell, pot.lock()->nrxx);
92+
hamilt::Operator<T>* op_gxc = new OperatorLRHxc<T>(nspin, naos, nocc, nvirt, psi_ks,
93+
*this->DM_trans, gint, this->pot_grad, ucell, orb_cutoff, gd, kv, pX, pc, pmat,
94+
{ 0 }, T(-2.0), ATYPE::CC_vo, hamilt::calculation_type::lr_dmtrans_gxc);
95+
assert(op_gxc != nullptr);
96+
std::cout << "op_gxc=" << op_ht << std::endl;
97+
this->ops->add(op_gxc);
98+
}
6799
// // test: op_ht only
68100
// delete this->ops;
69101
// this->ops = new OperatorLRHxc<T>(nspin, naos, nocc, nvirt, psi_ks,
70102
// *this->DM_diff, gint, pot_hxc_gs, ucell, orb_cutoff, gd, kv, pX, pc, pmat,
71103
// { 0 }, T(-2.0), ATYPE::CC_vo);
72-
#ifdef __EXX
73-
// add EXX operators here
74-
#endif
104+
75105
this->cal_dm_trans = [&, this](const int& is, const T* X)->void
76106
{
77107
const auto psi_ks_is = LR_Util::get_psi_spin(psi_ks, is, this->nk);

0 commit comments

Comments
 (0)