Skip to content

Commit a8bb37f

Browse files
committed
extract precondition function
1 parent a2fdb95 commit a8bb37f

File tree

7 files changed

+113
-55
lines changed

7 files changed

+113
-55
lines changed

python/pyabacus/src/hsolver/py_diago_dav_subspace.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,9 @@ class PyDiagoDavSubspace
130130
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
131131
};
132132

133+
hsolver::PreOP<std::complex<double>> pre_op(precond_vec, hsolver::transfunc::qe_pw<double>);
133134
obj = std::make_unique<hsolver::Diago_DavSubspace<std::complex<double>, base_device::DEVICE_CPU>>(
134-
precond_vec,
135+
hsolver::bind_pre_op(pre_op),
135136
nband,
136137
nbasis,
137138
dav_ndim,

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 4 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
using namespace hsolver;
1414

1515
template <typename T, typename Device>
16-
Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precondition_in,
16+
Diago_DavSubspace<T, Device>::Diago_DavSubspace(PreFunc<T>&& precondition_in,
1717
const int& nband_in,
1818
const int& nbasis_in,
1919
const int& david_ndim_in,
2020
const double& diag_thr_in,
2121
const int& diag_nmax_in,
2222
const bool& need_subspace_in,
2323
const diag_comm_info& diag_comm_in)
24-
: precondition(precondition_in), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in * david_ndim_in),
25-
diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), is_subspace(need_subspace_in), diag_comm(diag_comm_in)
24+
: precondition(std::forward<PreFunc<T>>(precondition_in)), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in* david_ndim_in),
25+
diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), is_subspace(need_subspace_in), diag_comm(diag_comm_in)
2626
{
2727
this->device = base_device::get_device_type<Device>(this->ctx);
2828

@@ -55,14 +55,6 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(const std::vector<Real>& precond
5555
resmem_complex_op()(this->ctx, this->vcc, this->nbase_x * this->nbase_x, "DAV::vcc");
5656
setmem_complex_op()(this->ctx, this->vcc, 0, this->nbase_x * this->nbase_x);
5757
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
58-
59-
#if defined(__CUDA) || defined(__ROCM)
60-
if (this->device == base_device::GpuDevice)
61-
{
62-
resmem_real_op()(this->ctx, this->d_precondition, nbasis_in);
63-
// syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition.data(), nbasis_in);
64-
}
65-
#endif
6658
}
6759

6860
template <typename T, typename Device>
@@ -74,13 +66,6 @@ Diago_DavSubspace<T, Device>::~Diago_DavSubspace()
7466
delmem_complex_op()(this->ctx, this->hcc);
7567
delmem_complex_op()(this->ctx, this->scc);
7668
delmem_complex_op()(this->ctx, this->vcc);
77-
78-
#if defined(__CUDA) || defined(__ROCM)
79-
if (this->device == base_device::GpuDevice)
80-
{
81-
delmem_real_op()(this->ctx, this->d_precondition);
82-
}
83-
#endif
8469
}
8570

8671
template <typename T, typename Device>
@@ -334,35 +319,7 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
334319
this->dim);
335320

336321
// "precondition!!!"
337-
std::vector<Real> pre(this->dim, 0.0);
338-
for (int m = 0; m < notconv; m++)
339-
{
340-
for (size_t i = 0; i < this->dim; i++)
341-
{
342-
// pre[i] = std::abs(this->precondition[i] - (*eigenvalue_iter)[m]);
343-
double x = std::abs(this->precondition[i] - (*eigenvalue_iter)[m]);
344-
pre[i] = 0.5 * (1.0 + x + sqrt(1 + (x - 1.0) * (x - 1.0)));
345-
}
346-
#if defined(__CUDA) || defined(__ROCM)
347-
if (this->device == base_device::GpuDevice)
348-
{
349-
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, pre.data(), this->dim);
350-
vector_div_vector_op<T, Device>()(this->ctx,
351-
this->dim,
352-
psi_iter + (nbase + m) * this->dim,
353-
psi_iter + (nbase + m) * this->dim,
354-
this->d_precondition);
355-
}
356-
else
357-
#endif
358-
{
359-
vector_div_vector_op<T, Device>()(this->ctx,
360-
this->dim,
361-
psi_iter + (nbase + m) * this->dim,
362-
psi_iter + (nbase + m) * this->dim,
363-
pre.data());
364-
}
365-
}
322+
this->precondition(psi_iter + nbase * this->dim, eigenvalue_iter->data(), this->dim, notconv);
366323

367324
// "normalize!!!" in order to improve numerical stability of subspace diagonalization
368325
std::vector<Real> psi_norm(notconv, 0.0);

source/module_hsolver/diago_dav_subspace.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <vector>
1212
#include <functional>
13+
#include "module_hsolver/precondition_funcs.h"
1314

1415
namespace hsolver
1516
{
@@ -24,7 +25,7 @@ class Diago_DavSubspace
2425
using Real = typename GetTypeReal<T>::type;
2526

2627
public:
27-
Diago_DavSubspace(const std::vector<Real>& precondition_in,
28+
Diago_DavSubspace(PreFunc<T>&& precondition_in, /// pass in a function, lambda or PreOP object
2829
const int& nband_in,
2930
const int& nbasis_in,
3031
const int& david_ndim_in,
@@ -67,9 +68,9 @@ class Diago_DavSubspace
6768
/// the maximum dimension of the reduced basis set
6869
const int nbase_x = 0;
6970

70-
/// precondition for diag
71-
const std::vector<Real>& precondition;
72-
Real* d_precondition = nullptr;
71+
/// The precondition operation, can be a function, lambda or PreOP object
72+
const PreFunc<T> precondition;
73+
// note that lambdas can only be passed by value
7374

7475
/// record for how many bands not have convergence eigenvalues
7576
int notconv = 0;

source/module_hsolver/hsolver_pw.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,8 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
503503
};
504504
bool scf = this->calculation_type == "nscf" ? false : true;
505505

506-
Diago_DavSubspace<T, Device> dav_subspace(pre_condition,
506+
PreOP<T, Device> pre_op(pre_condition, transfunc::qe_pw<Real>);
507+
Diago_DavSubspace<T, Device> dav_subspace(bind_pre_op(pre_op),
507508
psi.get_nbands(),
508509
psi.get_k_first() ? psi.get_current_nbas()
509510
: psi.get_nk() * psi.get_nbasis(),
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#include <cassert>
#include <cmath>
#include <functional>
#include <iostream> // for debugging
#include <vector>

#include "module_base/module_device/types.h"
#include "module_hsolver/kernels/math_kernel_op.h"
6+
namespace hsolver
7+
{
8+
/// @brief Scalar transforms applied to a single (shifted) preconditioner entry.
namespace transfunc
{
/// @brief Identity transform: returns x unchanged.
template <typename T>
T none(const T& x)
{
    return x;
}

/// @brief Smooth transform 0.5 * (1 + x + sqrt(1 + (x - 1)^2)).
/// The result is strictly positive for every finite x, because sqrt(1 + (x - 1)^2) > |x - 1|.
/// std::sqrt is used (rather than an unqualified sqrt) so that the overload matching the
/// argument type is selected and no global-namespace ::sqrt declaration is relied upon.
template <typename T>
T qe_pw(const T& x)
{
    return 0.5 * (1.0 + x + std::sqrt(1 + (x - 1.0) * (x - 1.0)));
}
} // namespace transfunc
14+
15+
template <typename T>
16+
using Real = typename GetTypeReal<T>::type;
17+
18+
/// @brief to be called in the iterative eigensolver.
19+
/// fixed parameters: object vector, eigenvalue, leading dimension, number of vectors
20+
template <typename T>
21+
using PreFunc = const std::function<void(T*, const Real<T>*, const size_t&, const size_t&)>;
22+
// using PreFunc = std::function<void(T*, const Real<T>*, const int&, const int&)>;
23+
24+
/// type1: Divide transfunc(precon_vec - eigen_subspace[m]) for each vector[m]
25+
///$X \to (A-\lambda I)^{-1} X$
26+
// There may be other types of operation than this one.
27+
template <typename T, typename Device = base_device::DEVICE_CPU>
28+
void div_trans_prevec_minus_eigen(T* ptr, const Real<T>* eig, const size_t& dim, const size_t& nvec,
29+
const Real<T>* const pre, Real<T>* const d_pre = nullptr, const std::function<Real<T>(const Real<T>&)>& transfunc = transfunc::none<Real<T>>)
30+
{
31+
using syncmem_var_h2d_op = base_device::memory::synchronize_memory_op<Real<T>, Device, base_device::DEVICE_CPU>;
32+
std::vector<Real<T>> pre_trans(dim, 0.0);
33+
const auto device = base_device::get_device_type<Device>({});
34+
35+
for (int m = 0; m < nvec; m++)
36+
{
37+
T* const ptr_m = ptr + m * dim;
38+
for (size_t i = 0; i < dim; i++) { pre_trans[i] = transfunc(pre[i] - eig[m]); }
39+
std::cout << std::endl;
40+
#if defined(__CUDA) || defined(__ROCM)
41+
if (device == base_device::GpuDevice)
42+
{
43+
assert(d_pre);
44+
syncmem_var_h2d_op()({}, {}, d_pre, pre_trans.data(), dim);
45+
vector_div_vector_op<T, Device>()({}, dim, ptr_m, ptr_m, d_pre);
46+
}
47+
else
48+
#endif
49+
{
50+
vector_div_vector_op<T, Device>()({}, dim, ptr_m, ptr_m, pre_trans.data());
51+
}
52+
}
53+
}
54+
55+
/// @brief A operator-like class of precondition function
56+
/// to encapsulate the pre-allocation of memory on different devices before starting the iterative eigensolver.
57+
/// One can pass the operatr() function of this class, or other custom lambdas/functions to eigensolvers.
58+
template <typename T, typename Device = base_device::DEVICE_CPU>
59+
struct PreOP
60+
{
61+
PreOP(const std::vector<Real<T>>& prevec, const std::function<Real<T>(const Real<T>&)>& transfunc = transfunc::none)
62+
: PreOP<T, Device>(prevec.data(), prevec.size(), transfunc) {}
63+
PreOP(const Real<T>* const prevec, const int& dim, const std::function<Real<T>(const Real<T>&)>& transfunc = transfunc::none)
64+
: prevec_(prevec), dim_(dim), transfunc_(transfunc),
65+
dev_(base_device::get_device_type<Device>({}))
66+
{
67+
#if defined(__CUDA) || defined(__ROCM)
68+
if (this->dev_ == base_device::GpuDevice) { resmem_real_op<T, Device>()({}, this->d_prevec_, dim_); }
69+
#endif
70+
}
71+
PreOP(const PreOP& other) = delete;
72+
~PreOP() {
73+
#if defined(__CUDA) || defined(__ROCM)
74+
if (this->dev_ == base_device::GpuDevice) { delmem_real_op<T, Device>()({}, this->d_precondition); }
75+
#endif
76+
}
77+
void operator()(T* ptr, const Real<T>* eig, const size_t& dim, const size_t& nvec) const
78+
{
79+
assert(dim <= dim_);
80+
div_trans_prevec_minus_eigen<T, Device>(ptr, eig, dim, nvec, prevec_, d_prevec_, transfunc_);
81+
}
82+
private:
83+
const Real<T>* const prevec_;
84+
const int dim_;
85+
Real<T>* d_prevec_;
86+
const std::function<Real<T>(const Real<T>&)> transfunc_;
87+
const base_device::AbacusDevice_t dev_;
88+
};
89+
90+
/// @brief Bind a PreOP object to a function
91+
template <typename T, typename Device>
92+
PreFunc<T> bind_pre_op(const PreOP<T, Device>& pre_op)
93+
{
94+
return std::bind(&PreOP<T, Device>::operator(), &pre_op, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4);
95+
}
96+
}

source/module_lr/hsolver_lrtd.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ namespace LR
7373
auto hpsi_func = [&hm](T* psi_in, T* hpsi, const int ld_psi, const int nvec) {hm.hPsi(psi_in, hpsi, ld_psi, nvec);};
7474
auto spsi_func = [&hm](const T* psi_in, T* spsi, const int ld_psi, const int nvec)
7575
{ std::memcpy(spsi, psi_in, sizeof(T) * ld_psi * nvec); };
76+
auto pre_func = [&precondition](T* ptr, const Real<T>* eig, const int& ld, const int& nvec)->void
77+
{ hsolver::div_trans_prevec_minus_eigen(ptr, eig, ld, nvec, precondition.data()); };
7678

7779
if (method == "dav")
7880
{
@@ -88,7 +90,7 @@ namespace LR
8890
}
8991
else if (method == "dav_subspace") //need refactor
9092
{
91-
hsolver::Diago_DavSubspace<T> dav_subspace(precondition,
93+
hsolver::Diago_DavSubspace<T> dav_subspace(pre_func,
9294
nband,
9395
dim,
9496
PARAM.inp.pw_diag_ndim,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
totexcitationenergyref 0.784274
1+
totexcitationenergyref 0.786881

0 commit comments

Comments
 (0)