Skip to content

Commit 4f4f356

Browse files
committed
apply to dav
1 parent a8bb37f commit 4f4f356

File tree

10 files changed

+135
-128
lines changed

10 files changed

+135
-128
lines changed

python/pyabacus/src/hsolver/py_diago_dav_subspace.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,10 @@ class PyDiagoDavSubspace
130130
std::copy(hpsi_ptr, hpsi_ptr + nvec * ld_psi, hpsi_out);
131131
};
132132

133-
hsolver::PreOP<std::complex<double>> pre_op(precond_vec, hsolver::transfunc::qe_pw<double>);
133+
hsolver::PreOP<std::complex<double>, base_device::DEVICE_CPU, hsolver::fvec::DivTransMinusEigKernel<std::complex<double>, base_device::DEVICE_CPU>>
134+
pre_op(precond_vec, hsolver::fvec::div_trans_prevec_minus_eigen<std::complex<double>>, hsolver::fval::qe_pw<double>);
134135
obj = std::make_unique<hsolver::Diago_DavSubspace<std::complex<double>, base_device::DEVICE_CPU>>(
135-
hsolver::bind_pre_op(pre_op),
136+
pre_op.get(),
136137
nband,
137138
nbasis,
138139
dav_ndim,

python/pyabacus/src/hsolver/py_diago_david.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,9 @@ class PyDiagoDavid
137137
syncmem_op()(this->ctx, this->ctx, spsi_out, psi_in, static_cast<size_t>(nbands * nrow));
138138
};
139139

140+
hsolver::PreOP<std::complex<double>> pre_op(precond_vec);
140141
obj = std::make_unique<hsolver::DiagoDavid<std::complex<double>, base_device::DEVICE_CPU>>(
141-
precond_vec.data(),
142+
pre_op.get(),
142143
nband,
143144
nbasis,
144145
dav_ndim,

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
using namespace hsolver;
1414

1515
template <typename T, typename Device>
16-
Diago_DavSubspace<T, Device>::Diago_DavSubspace(PreFunc<T>&& precondition_in,
16+
Diago_DavSubspace<T, Device>::Diago_DavSubspace(PreFunc&& precondition_in,
1717
const int& nband_in,
1818
const int& nbasis_in,
1919
const int& david_ndim_in,
2020
const double& diag_thr_in,
2121
const int& diag_nmax_in,
2222
const bool& need_subspace_in,
2323
const diag_comm_info& diag_comm_in)
24-
: precondition(std::forward<PreFunc<T>>(precondition_in)), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in* david_ndim_in),
24+
: precondition(std::forward<PreFunc>(precondition_in)), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in* david_ndim_in),
2525
diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), is_subspace(need_subspace_in), diag_comm(diag_comm_in)
2626
{
2727
this->device = base_device::get_device_type<Device>(this->ctx);

source/module_hsolver/diago_dav_subspace.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ class Diago_DavSubspace
2424
// otherwise return the real type of T(complex<float>, complex<double>)
2525
using Real = typename GetTypeReal<T>::type;
2626

27-
public:
28-
Diago_DavSubspace(PreFunc<T>&& precondition_in, /// pass in a function, lambda or PreOP object
27+
using PreFunc = fvec::DivTransMinusEig<T>;
28+
public:
29+
Diago_DavSubspace(PreFunc&& precondition_in, /// pass in a function, lambda or PreOP object
2930
const int& nband_in,
3031
const int& nbasis_in,
3132
const int& david_ndim_in,
@@ -69,7 +70,7 @@ class Diago_DavSubspace
6970
const int nbase_x = 0;
7071

7172
/// The precondition operation, can be a function, lambda or PreOP object
72-
const PreFunc<T> precondition;
73+
const PreFunc precondition;
7374
// note that lambdas can only passed by value
7475

7576
/// record for how many bands not have convergence eigenvalues

source/module_hsolver/diago_david.cpp

Lines changed: 6 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,16 @@ using namespace hsolver;
3232
* @note Auxiliary memory is allocated in the constructor and deallocated in the destructor.
3333
*/
3434
template <typename T, typename Device>
35-
DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
35+
DiagoDavid<T, Device>::DiagoDavid(PreFunc&& precondition_in,
3636
const int nband_in,
3737
const int dim_in,
3838
const int david_ndim_in,
3939
const bool use_paw_in,
4040
const diag_comm_info& diag_comm_in)
41-
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in)
41+
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in* nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in),
42+
precondition(std::forward<PreFunc>(precondition_in))
4243
{
4344
this->device = base_device::get_device_type<Device>(this->ctx);
44-
this->precondition = precondition_in;
4545

4646
this->one = &one_;
4747
this->zero = &zero_;
@@ -110,15 +110,6 @@ DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
110110
// lagrange_matrix(nband, nband); // for orthogonalization
111111
resmem_complex_op()(this->ctx, this->lagrange_matrix, nband * nband);
112112
setmem_complex_op()(this->ctx, this->lagrange_matrix, 0, nband * nband);
113-
114-
#if defined(__CUDA) || defined(__ROCM)
115-
// device precondition array
116-
if (this->device == base_device::GpuDevice)
117-
{
118-
resmem_var_op()(this->ctx, this->d_precondition, dim);
119-
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, this->d_precondition, this->precondition, dim);
120-
}
121-
#endif
122113
}
123114

124115
/**
@@ -139,13 +130,6 @@ DiagoDavid<T, Device>::~DiagoDavid()
139130
delmem_complex_op()(this->ctx, this->vcc);
140131
delmem_complex_op()(this->ctx, this->lagrange_matrix);
141132
base_device::memory::delete_memory_op<Real, base_device::DEVICE_CPU>()(this->cpu_ctx, this->eigenvalue);
142-
// If the device is a GPU device, free the d_precondition array.
143-
#if defined(__CUDA) || defined(__ROCM)
144-
if (this->device == base_device::GpuDevice)
145-
{
146-
delmem_var_op()(this->ctx, this->d_precondition);
147-
}
148-
#endif
149133
}
150134

151135
template <typename T, typename Device>
@@ -499,40 +483,14 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
499483
dim // LDC: if(N) max(1, m)
500484
);
501485
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
502-
503486
// Preconditioning
504487
// basis[nbase] = T * basis[nbase] = T * (H - lambda * S) * psi
505488
// where T, the preconditioner, is an approximate inverse of H
506489
// T is a diagonal stored in array `precondition`
507490
// to do preconditioning, divide each column of basis by the corresponding element of precondition
508-
for (int m = 0; m < notconv; m++)
509-
{
510-
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
511-
if (this->device == base_device::GpuDevice)
512-
{
513-
#if defined(__CUDA) || defined(__ROCM)
514-
vector_div_vector_op<T, Device>()(this->ctx,
515-
dim,
516-
basis + dim*(nbase + m),
517-
basis + dim*(nbase + m),
518-
this->d_precondition);
519-
#endif
520-
}
521-
else
522-
{
523-
vector_div_vector_op<T, Device>()(this->ctx,
524-
dim,
525-
basis + dim*(nbase + m),
526-
basis + dim*(nbase + m),
527-
this->precondition);
528-
}
529-
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
530-
// for (int ig = 0; ig < dim; ig++)
531-
// {
532-
// ppsi[ig] /= this->precondition[ig];
533-
// }
534-
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
535-
}
491+
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
492+
this->precondition(basis + dim * nbase, dim, notconv);
493+
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
536494

537495
// there is a nbase to nbase + notconv band orthogonalise
538496
// plan for SchmidtOrth

source/module_hsolver/diago_david.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include "module_base/module_device/memory_op.h"// base_device::memory
77

88
#include "module_hsolver/diag_comm_info.h"
9-
9+
#include "module_hsolver/precondition_funcs.h"
1010
#include <vector>
1111
#include <functional>
1212

@@ -21,10 +21,11 @@ class DiagoDavid
2121
// return T if T is real type(float, double),
2222
// otherwise return the real type of T(complex<float>, complex<double>)
2323
using Real = typename GetTypeReal<T>::type;
24-
25-
public:
2624

27-
DiagoDavid(const Real* precondition_in,
25+
using PreFunc = fvec::Div<T>;
26+
public:
27+
28+
DiagoDavid(PreFunc&& precondition_in,
2829
const int nband_in,
2930
const int dim_in,
3031
const int david_ndim_in,
@@ -102,8 +103,7 @@ class DiagoDavid
102103
int notconv = 0;
103104

104105
/// precondition for diag, diagonal approximation of matrix A(i.e. Hamilt)
105-
const Real* precondition = nullptr;
106-
Real* d_precondition = nullptr;
106+
const PreFunc precondition;
107107

108108
/// eigenvalue results
109109
Real* eigenvalue = nullptr;

source/module_hsolver/hsolver_pw.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
503503
};
504504
bool scf = this->calculation_type == "nscf" ? false : true;
505505

506-
PreOP<T, Device> pre_op(pre_condition, transfunc::qe_pw<Real>);
507-
Diago_DavSubspace<T, Device> dav_subspace(bind_pre_op(pre_op),
506+
// const auto pre_op = make_pre_op(pre_condition, fvec::div_trans_prevec_minus_eigen<T, Device>, fval::qe_pw<Real>);
507+
const PreOP<T, Device, fvec::DivTransMinusEigKernel<T, Device>> pre_op(pre_condition, fvec::div_trans_prevec_minus_eigen<T, Device>, fval::qe_pw<Real>);
508+
Diago_DavSubspace<T, Device> dav_subspace(pre_op.get(),
508509
psi.get_nbands(),
509510
psi.get_k_first() ? psi.get_current_nbas()
510511
: psi.get_nk() * psi.get_nbasis(),
@@ -573,7 +574,9 @@ void HSolverPW<T, Device>::hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
573574
ModuleBase::timer::tick("David", "spsi_func");
574575
};
575576

576-
DiagoDavid<T, Device> david(pre_condition.data(),
577+
// const auto pre_op = make_pre_op(pre_condition, fvec::div_prevec<T, Device>);
578+
const PreOP<T, Device> pre_op(pre_condition);
579+
DiagoDavid<T, Device> david(pre_op.get(),
577580
nband,
578581
dim,
579582
PARAM.inp.pw_diag_ndim,

0 commit comments

Comments
 (0)