Skip to content

Commit 2a35de7

Browse files
committed
diag-precondition
1 parent 2357507 commit 2a35de7

File tree

3 files changed: 18 additions (+18) and 10 deletions (-10)

3 files changed: 18 additions (+18) and 10 deletions (-10)

source/module_lr/esolver_lrtd_lcao.cpp

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
#include "module_base/scalapack_connector.h"
1717
#include "module_parameter/parameter.h"
1818
#include "module_lr/ri_benchmark/ri_benchmark.h"
19+
#include "module_lr/operator_casida/operator_lr_diag.h" // for precondition
1920

2021
#ifdef __EXX
2122
template<>
@@ -444,20 +445,29 @@ void LR::ESolver_LR<T, TR>::runner(int istep, UnitCell& cell)
444445
if (GlobalV::MY_RANK == 0) { assert(nst == LR_Util::write_value(efile(label), prec, e, nst)); }
445446
assert(nst * dim == LR_Util::write_value(vfile(label), prec, v, nst, dim));
446447
};
448+
std::vector<double> precondition(this->input.lr_solver == "lapack" ? 0 : nloc_per_band, 1.0);
447449
// allocate and initialize A matrix and density matrix
448450
if (openshell)
449451
{
452+
for (int is : {0, 1})
453+
{
454+
const int offset_is = is * this->paraX_[0].get_local_size();
455+
OperatorLRDiag<double> pre_op(this->eig_ks.c + is * nk * (nocc[0] + nvirt[0]), this->paraX_[is], this->nk, this->nocc[is], this->nvirt[is]);
456+
if (input.lr_solver != "lapack") { pre_op.act(1, offset_is, 1, precondition.data() + offset_is, precondition.data() + offset_is); }
457+
}
450458
std::cout << "Solving spin-conserving excitation for open-shell system." << std::endl;
451459
HamiltULR<T> hulr(xc_kernel, nspin, this->nbasis, this->nocc, this->nvirt, this->ucell, orb_cutoff_, GlobalC::GridD, *this->psi_ks, this->eig_ks,
452460
#ifdef __EXX
453461
this->exx_lri, this->exx_info.info_global.hybrid_alpha,
454462
#endif
455463
this->gint_, this->pot, this->kv, this->paraX_, this->paraC_, this->paraMat_);
456-
LR::HSolver::solve(hulr, this->X[0].template data<T>(), nloc_per_band, nstates, this->pelec->ekb.c, this->input.lr_solver, this->input.lr_thr);
464+
LR::HSolver::solve(hulr, this->X[0].template data<T>(), nloc_per_band, nstates, this->pelec->ekb.c, this->input.lr_solver, this->input.lr_thr, precondition);
457465
if (input.out_wfc_lr) { write_states("openshell", this->pelec->ekb.c, this->X[0].template data<T>(), nloc_per_band, nstates); }
458466
}
459467
else
460468
{
469+
OperatorLRDiag<double> pre_op(this->eig_ks.c, this->paraX_[0], this->nk, this->nocc[0], this->nvirt[0]);
470+
if (input.lr_solver != "lapack") { pre_op.act(1, nloc_per_band, 1, precondition.data(), precondition.data()); }
461471
auto spin_types = std::vector<std::string>({ "singlet", "triplet" });
462472
for (int is = 0;is < nspin;++is)
463473
{
@@ -470,7 +480,7 @@ void LR::ESolver_LR<T, TR>::runner(int istep, UnitCell& cell)
470480
spin_types[is], input.ri_hartree_benchmark, (input.ri_hartree_benchmark == "aims" ? input.aims_nbasis : std::vector<int>({})));
471481
// solve the Casida equation
472482
LR::HSolver::solve(hlr, this->X[is].template data<T>(), nloc_per_band, nstates,
473-
this->pelec->ekb.c + is * nstates, this->input.lr_solver, this->input.lr_thr/*,
483+
this->pelec->ekb.c + is * nstates, this->input.lr_solver, this->input.lr_thr, precondition/*,
474484
!std::set<std::string>({ "hf", "hse" }).count(this->xc_kernel)*/); //whether the kernel is Hermitian
475485
if (input.out_wfc_lr) { write_states(spin_types[is], this->pelec->ekb.c + is * nstates, this->X[is].template data<T>(), nloc_per_band, nstates); }
476486
}

source/module_lr/hsolver_lrtd.hpp

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -32,15 +32,15 @@ namespace LR
3232
double* eig,
3333
const std::string method,
3434
const Real<T>& diag_ethr, ///< threshold for diagonalization
35+
const std::vector<Real<T>>& precondition,
3536
const bool hermitian = true)
3637
{
3738
ModuleBase::TITLE("HSolverLR", "solve");
3839
const std::vector<std::string> spin_types = { "singlet", "triplet" };
3940
// note: if not TDA, the eigenvalues will be complex
4041
// then we will need a new constructor of DiagoDavid
4142

42-
// 1. allocate precondition and eigenvalue
43-
std::vector<Real<T>> precondition(dim);
43+
// 1. allocate eigenvalue
4444
std::vector<Real<T>> eigenvalue(nband); //nstates
4545
// 2. select the method
4646
#ifdef __MPI
@@ -67,9 +67,7 @@ namespace LR
6767
}
6868
else
6969
{
70-
// 3. set precondition and diagethr
71-
for (int i = 0; i < dim; ++i) { precondition[i] = static_cast<Real<T>>(1.0); }
72-
70+
// 3. set maxiter and funcs
7371
const int maxiter = hsolver::DiagoIterAssist<T>::PW_DIAG_NMAX;
7472

7573
auto hpsi_func = [&hm](T* psi_in, T* hpsi, const int ld_psi, const int nvec) {hm.hPsi(psi_in, hpsi, ld_psi, nvec);};
@@ -139,7 +137,8 @@ namespace LR
139137

140138
auto psi_tensor = ct::TensorMap(psi, ct::DataTypeToEnum<T>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ nband, dim }));
141139
auto eigen_tensor = ct::TensorMap(eigenvalue.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ nband }));
142-
auto precon_tensor = ct::TensorMap(precondition.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ dim }));
140+
std::vector<Real<T>> precondition_(precondition); //since TensorMap does not support const pointer
141+
auto precon_tensor = ct::TensorMap(precondition_.data(), ct::DataTypeToEnum<Real<T>>::value, ct::DeviceType::CpuDevice, ct::TensorShape({ dim }));
143142
auto hpsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& hpsi) {hm.hPsi(psi_in.data<T>(), hpsi.data<T>(), psi_in.shape().dim_size(0) /*nbasis_local*/, 1/*band-by-band*/);};
144143
auto spsi_func = [&hm](const ct::Tensor& psi_in, ct::Tensor& spsi)
145144
{ std::memcpy(spsi.data<T>(), psi_in.data<T>(), sizeof(T) * psi_in.NumElements()); };

source/module_lr/operator_casida/operator_lr_diag.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -46,9 +46,8 @@ namespace LR
4646
const bool is_first_node = false)const override
4747
{
4848
ModuleBase::TITLE("OperatorLRDiag", "act");
49-
const int nlocal_ph = nk * pX.get_local_size(); // local size of particle-hole basis
5049
hsolver::vector_mul_vector_op<T, Device>()(this->ctx,
51-
nk * pX.get_local_size(),
50+
nk * pX.get_local_size(), // local size of particle-hole basis
5251
hpsi,
5352
psi_in,
5453
this->eig_ks_diff.c);

0 commit comments

Comments
 (0)