Skip to content

Commit fcff75e

Browse files
committed
fix UTs
1 parent 4f4f356 commit fcff75e

File tree

4 files changed

+13
-8
lines changed

4 files changed

+13
-8
lines changed

source/module_hsolver/test/diago_david_float_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ class DiagoDavPrepare
9292

9393
const int dim = phi.get_current_nbas() ;
9494
const int nband = phi.get_nbands();
95-
const int ld_psi =phi.get_nbasis();
96-
hsolver::DiagoDavid<std::complex<float>> dav(precondition, nband, dim, order, false, comm_info);
95+
const int ld_psi = phi.get_nbasis();
96+
const hsolver::PreOP<std::complex<float>> pre_op(precondition, dim);
97+
hsolver::DiagoDavid<std::complex<float>> dav(pre_op.get(), nband, dim, order, false, comm_info);
9798

9899
hsolver::DiagoIterAssist<std::complex<float>>::PW_DIAG_NMAX = maxiter;
99100
hsolver::DiagoIterAssist<std::complex<float>>::PW_DIAG_THR = eps;

source/module_hsolver/test/diago_david_real_test.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ class DiagoDavPrepare
9292
const int dim = phi.get_current_nbas();
9393
const int nband = phi.get_nbands();
9494
const int ld_psi = phi.get_nbasis();
95-
hsolver::DiagoDavid<double> dav(precondition, nband, dim, order, false, comm_info);
95+
const hsolver::PreOP<double> pre_op(precondition, dim);
96+
hsolver::DiagoDavid<double> dav(pre_op.get(), nband, dim, order, false, comm_info);
9697

9798
hsolver::DiagoIterAssist<double>::PW_DIAG_NMAX = maxiter;
9899
hsolver::DiagoIterAssist<double>::PW_DIAG_THR = eps;

source/module_hsolver/test/diago_david_test.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ class DiagoDavPrepare
9191

9292
const int dim = phi.get_current_nbas();
9393
const int nband = phi.get_nbands();
94-
const int ld_psi = phi.get_nbasis();
95-
hsolver::DiagoDavid<std::complex<double>> dav(precondition, nband, dim, order, false, comm_info);
94+
const int ld_psi = phi.get_nbasis();
95+
const auto pre_func = [&precondition](std::complex<double>* ptr, const int& ld, const int& nvec)->void
96+
{ hsolver::fvec::div_prevec(ptr, ld, nvec, precondition); };
97+
hsolver::DiagoDavid<std::complex<double>> dav(pre_func, nband, dim, order, false, comm_info);
9698

9799
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX = maxiter;
98100
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR = eps;

source/module_hsolver/test/hsolver_pw_sup.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "module_basis/module_pw/pw_basis_k.h"
2+
#include "module_hsolver/precondition_funcs.h"
23

34
namespace ModulePW {
45

@@ -121,15 +122,15 @@ template class DiagoCG<std::complex<float>, base_device::DEVICE_CPU>;
121122
template class DiagoCG<std::complex<double>, base_device::DEVICE_CPU>;
122123

123124
template <typename T, typename Device>
124-
DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
125+
DiagoDavid<T, Device>::DiagoDavid(PreFunc&& precondition_in,
125126
const int nband_in,
126127
const int dim_in,
127128
const int david_ndim_in,
128129
const bool use_paw_in,
129130
const diag_comm_info& diag_comm_in)
130-
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in) {
131+
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in* nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in),
132+
precondition(std::forward<PreFunc>(precondition_in)) {
131133
this->device = base_device::get_device_type<Device>(this->ctx);
132-
this->precondition = precondition_in;
133134

134135
test_david = 2;
135136
// 1: check which function is called and which step is executed

0 commit comments

Comments
 (0)