Skip to content

Commit 17cac27

Browse files
committed
fix UTs
1 parent 4f4f356 commit 17cac27

File tree

8 files changed

+35
-27
lines changed

8 files changed

+35
-27
lines changed

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ Diago_DavSubspace<T, Device>::Diago_DavSubspace(PreFunc&& precondition_in,
2121
const int& diag_nmax_in,
2222
const bool& need_subspace_in,
2323
const diag_comm_info& diag_comm_in)
24-
: precondition(std::forward<PreFunc>(precondition_in)), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in* david_ndim_in),
24+
: precondition(precondition_in), n_band(nband_in), dim(nbasis_in), nbase_x(nband_in* david_ndim_in),
2525
diag_thr(diag_thr_in), iter_nmax(diag_nmax_in), is_subspace(need_subspace_in), diag_comm(diag_comm_in)
2626
{
2727
this->device = base_device::get_device_type<Device>(this->ctx);

source/module_hsolver/diago_david.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ DiagoDavid<T, Device>::DiagoDavid(PreFunc&& precondition_in,
3939
const bool use_paw_in,
4040
const diag_comm_info& diag_comm_in)
4141
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in* nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in),
42-
precondition(std::forward<PreFunc>(precondition_in))
42+
precondition(precondition_in)
4343
{
4444
this->device = base_device::get_device_type<Device>(this->ctx);
4545

source/module_hsolver/precondition_funcs.h

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include "module_base/module_device/types.h"
55
#include "module_base/module_device/memory_op.h"
66
#include "module_hsolver/kernels/math_kernel_op.h"
7+
8+
/// @brief Preconditioner Function Library
9+
/// Users can add other types of operation than the following ones at one's need.
710
namespace hsolver
811
{
912
template <typename T>
@@ -19,24 +22,21 @@ namespace hsolver
1922
/// @brief Transform vectors
2023
namespace fvec
2124
{
22-
/// @brief To be called in the iterative eigensolver.
23-
/// Users can add other types of operation than the following ones at one's need.
24-
/// fixed parameters: object vector, eigenvalue, leading dimension, number of vectors
25-
2625
///---------------------------------------------------------------------------------------------
2726
/// type 1: directly divide each vector by the precondition vector
2827
///---------------------------------------------------------------------------------------------
2928
template <typename T, typename Device = base_device::DEVICE_CPU>
3029
void div_prevec(T* ptr, const size_t& dim, const size_t& nvec,
3130
const Real<T>* const pre)
3231
{
32+
Device* ctx = {};
3333
for (int m = 0; m < nvec; m++)
3434
{
3535
T* const ptr_m = ptr + m * dim;
36-
vector_div_vector_op<T, Device>()({}, dim, ptr_m, ptr_m, pre);
36+
vector_div_vector_op<T, Device>()(ctx, dim, ptr_m, ptr_m, pre);
3737
}
3838
}
39-
/// calling intereface in the eigensolver
39+
/// Intereface to be called in the eigensolver
4040
template <typename T>
4141
using Div = std::function<void(T*, const size_t&, const size_t&)>;
4242
// Kernel function full of dependence
@@ -54,6 +54,8 @@ namespace hsolver
5454
using syncmem_var_h2d_op = base_device::memory::synchronize_memory_op<Real<T>, Device, base_device::DEVICE_CPU>;
5555
std::vector<Real<T>> pre_trans(dim, 0.0);
5656
const auto device = base_device::get_device_type<Device>({});
57+
Device* ctx = {};
58+
base_device::DEVICE_CPU* cpu_ctx = {};
5759

5860
for (int m = 0; m < nvec; m++)
5961
{
@@ -63,27 +65,28 @@ namespace hsolver
6365
if (device == base_device::GpuDevice)
6466
{
6567
assert(d_pre);
66-
syncmem_var_h2d_op()({}, {}, d_pre, pre_trans.data(), dim);
67-
vector_div_vector_op<T, Device>()({}, dim, ptr_m, ptr_m, d_pre);
68+
syncmem_var_h2d_op()(ctx, cpu_ctx, d_pre, pre_trans.data(), dim);
69+
vector_div_vector_op<T, Device>()(ctx, dim, ptr_m, ptr_m, d_pre);
6870
}
6971
else
7072
#endif
7173
{
72-
vector_div_vector_op<T, Device>()({}, dim, ptr_m, ptr_m, pre_trans.data());
74+
vector_div_vector_op<T, Device>()(ctx, dim, ptr_m, ptr_m, pre_trans.data());
7375
}
7476
}
7577
}
76-
/// calling intereface in the eigensolver
78+
/// Intereface to be called in the eigensolver
7779
template <typename T>
7880
using DivTransMinusEig = std::function<void(T*, const Real<T>*, const size_t&, const size_t&)>;
79-
// Kernel function full of dependence
81+
/// Kernel function full of dependence
8082
template <typename T, typename Device = base_device::DEVICE_CPU>
8183
using DivTransMinusEigKernel = std::function<decltype(div_trans_prevec_minus_eigen<T, Device>)>;
8284
}
8385

8486
/// @brief A operator-like class of precondition function
8587
/// to encapsulate the pre-allocation of memory on different devices before starting the iterative eigensolver.
86-
/// One can pass the operatr() function of this class, or other custom lambdas/functions to eigensolvers.
88+
/// One can use `.get()` interface to get the function to be called by the eigensovler,
89+
/// or pass a custom lambdas/function to replace the one returned by `.get()`.
8790
template <typename T, typename Device = base_device::DEVICE_CPU, typename Kernel_t = fvec::DivKernel<T, Device>>
8891
struct PreOP
8992
{

source/module_hsolver/test/CMakeLists.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ if (ENABLE_MPI)
7575
SOURCES test_hsolver_pw.cpp ../hsolver_pw.cpp ../hsolver_lcaopw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp
7676
)
7777

78-
# AddTest(
79-
# TARGET HSolver_sdft
80-
# LIBS parameter ${math_libs} psi device base container
81-
# SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp
82-
# )
78+
AddTest(
79+
TARGET HSolver_sdft
80+
LIBS parameter ${math_libs} psi device base container
81+
SOURCES test_hsolver_sdft.cpp ../hsolver_pw_sdft.cpp ../hsolver_pw.cpp ../diago_bpcg.cpp ../diago_dav_subspace.cpp ../diag_const_nums.cpp ../diago_iter_assist.cpp
82+
)
8383

8484
if(ENABLE_LCAO)
8585
if(USE_ELPA)

source/module_hsolver/test/diago_david_float_test.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ class DiagoDavPrepare
9292

9393
const int dim = phi.get_current_nbas() ;
9494
const int nband = phi.get_nbands();
95-
const int ld_psi =phi.get_nbasis();
96-
hsolver::DiagoDavid<std::complex<float>> dav(precondition, nband, dim, order, false, comm_info);
95+
const int ld_psi = phi.get_nbasis();
96+
const hsolver::PreOP<std::complex<float>> pre_op(precondition, dim);
97+
hsolver::DiagoDavid<std::complex<float>> dav(pre_op.get(), nband, dim, order, false, comm_info);
9798

9899
hsolver::DiagoIterAssist<std::complex<float>>::PW_DIAG_NMAX = maxiter;
99100
hsolver::DiagoIterAssist<std::complex<float>>::PW_DIAG_THR = eps;

source/module_hsolver/test/diago_david_real_test.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ class DiagoDavPrepare
9292
const int dim = phi.get_current_nbas();
9393
const int nband = phi.get_nbands();
9494
const int ld_psi = phi.get_nbasis();
95-
hsolver::DiagoDavid<double> dav(precondition, nband, dim, order, false, comm_info);
95+
const hsolver::PreOP<double> pre_op(precondition, dim);
96+
hsolver::DiagoDavid<double> dav(pre_op.get(), nband, dim, order, false, comm_info);
9697

9798
hsolver::DiagoIterAssist<double>::PW_DIAG_NMAX = maxiter;
9899
hsolver::DiagoIterAssist<double>::PW_DIAG_THR = eps;

source/module_hsolver/test/diago_david_test.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,10 @@ class DiagoDavPrepare
9191

9292
const int dim = phi.get_current_nbas();
9393
const int nband = phi.get_nbands();
94-
const int ld_psi = phi.get_nbasis();
95-
hsolver::DiagoDavid<std::complex<double>> dav(precondition, nband, dim, order, false, comm_info);
94+
const int ld_psi = phi.get_nbasis();
95+
const auto pre_func = [&precondition](std::complex<double>* ptr, const int& ld, const int& nvec)->void
96+
{ hsolver::fvec::div_prevec(ptr, ld, nvec, precondition); };
97+
hsolver::DiagoDavid<std::complex<double>> dav(pre_func, nband, dim, order, false, comm_info);
9698

9799
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_NMAX = maxiter;
98100
hsolver::DiagoIterAssist<std::complex<double>>::PW_DIAG_THR = eps;

source/module_hsolver/test/hsolver_pw_sup.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "module_basis/module_pw/pw_basis_k.h"
2+
#include "module_hsolver/precondition_funcs.h"
23

34
namespace ModulePW {
45

@@ -121,15 +122,15 @@ template class DiagoCG<std::complex<float>, base_device::DEVICE_CPU>;
121122
template class DiagoCG<std::complex<double>, base_device::DEVICE_CPU>;
122123

123124
template <typename T, typename Device>
124-
DiagoDavid<T, Device>::DiagoDavid(const Real* precondition_in,
125+
DiagoDavid<T, Device>::DiagoDavid(PreFunc&& precondition_in,
125126
const int nband_in,
126127
const int dim_in,
127128
const int david_ndim_in,
128129
const bool use_paw_in,
129130
const diag_comm_info& diag_comm_in)
130-
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in * nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in) {
131+
: nband(nband_in), dim(dim_in), nbase_x(david_ndim_in* nband_in), david_ndim(david_ndim_in), use_paw(use_paw_in), diag_comm(diag_comm_in),
132+
precondition(std::forward<PreFunc>(precondition_in)) {
131133
this->device = base_device::get_device_type<Device>(this->ctx);
132-
this->precondition = precondition_in;
133134

134135
test_david = 2;
135136
// 1: check which function is called and which step is executed

0 commit comments

Comments
 (0)