44#include " module_base/module_device/types.h"
55#include " module_base/module_device/memory_op.h"
66#include " module_hsolver/kernels/math_kernel_op.h"
7+
8+ // / @brief Preconditioner Function Library
9+ // / Users can add operation types beyond the following ones as needed.
710namespace hsolver
811{
912 template <typename T>
@@ -19,24 +22,21 @@ namespace hsolver
1922 // / @brief Transform vectors
2023 namespace fvec
2124 {
22- // / @brief To be called in the iterative eigensolver.
23- // / Users can add other types of operation than the following ones at one's need.
24- // / fixed parameters: object vector, eigenvalue, leading dimension, number of vectors
25-
2625 // /---------------------------------------------------------------------------------------------
2726 // / type 1: directly divide each vector by the precondition vector
2827 // /---------------------------------------------------------------------------------------------
2928 template <typename T, typename Device = base_device::DEVICE_CPU>
3029 void div_prevec (T* ptr, const size_t & dim, const size_t & nvec,
3130 const Real<T>* const pre )
3231 {
32+ Device* ctx = {};
3333 for (int m = 0 ; m < nvec; m++)
3434 {
3535 T* const ptr_m = ptr + m * dim;
36- vector_div_vector_op<T, Device>()({} , dim, ptr_m, ptr_m, pre );
36+ vector_div_vector_op<T, Device>()(ctx , dim, ptr_m, ptr_m, pre );
3737 }
3838 }
39- // / calling intereface in the eigensolver
39+ // / Interface to be called in the eigensolver
/// Type-erased callable with signature void(T*, const size_t&, const size_t&).
/// NOTE(review): the three arguments presumably mirror div_prevec's (ptr, dim, nvec) with the
/// precondition vector already bound - confirm against the code that builds this callable.
4040 template <typename T>
4141 using Div = std::function<void (T*, const size_t &, const size_t &)>;
4242 // Kernel function full of dependence
@@ -54,6 +54,8 @@ namespace hsolver
5454 using syncmem_var_h2d_op = base_device::memory::synchronize_memory_op<Real<T>, Device, base_device::DEVICE_CPU>;
5555 std::vector<Real<T>> pre_trans (dim, 0.0 );
5656 const auto device = base_device::get_device_type<Device>({});
57+ Device* ctx = {};
58+ base_device::DEVICE_CPU* cpu_ctx = {};
5759
5860 for (int m = 0 ; m < nvec; m++)
5961 {
@@ -63,27 +65,28 @@ namespace hsolver
6365 if (device == base_device::GpuDevice)
6466 {
6567 assert (d_pre);
66- syncmem_var_h2d_op ()({}, {} , d_pre, pre_trans.data (), dim);
67- vector_div_vector_op<T, Device>()({} , dim, ptr_m, ptr_m, d_pre);
68+ syncmem_var_h2d_op ()(ctx, cpu_ctx , d_pre, pre_trans.data (), dim);
69+ vector_div_vector_op<T, Device>()(ctx , dim, ptr_m, ptr_m, d_pre);
6870 }
6971 else
7072#endif
7173 {
72- vector_div_vector_op<T, Device>()({} , dim, ptr_m, ptr_m, pre_trans.data ());
74+ vector_div_vector_op<T, Device>()(ctx , dim, ptr_m, ptr_m, pre_trans.data ());
7375 }
7476 }
7577 }
76- // / calling intereface in the eigensolver
78+ // / Interface to be called in the eigensolver
/// Type-erased callable with signature void(T*, const Real<T>*, const size_t&, const size_t&).
/// NOTE(review): the Real<T> pointer presumably carries the eigenvalues and the two sizes the
/// leading dimension and number of vectors - confirm against div_trans_prevec_minus_eigen.
7779 template <typename T>
7880 using DivTransMinusEig = std::function<void (T*, const Real<T>*, const size_t &, const size_t &)>;
79- // Kernel function full of dependence
81+ // / Kernel function full of dependence
/// std::function wrapper whose signature is taken, via decltype, from
/// div_trans_prevec_minus_eigen<T, Device>; it therefore exposes every parameter of that function.
8082 template <typename T, typename Device = base_device::DEVICE_CPU>
8183 using DivTransMinusEigKernel = std::function<decltype (div_trans_prevec_minus_eigen<T, Device>)>;
8284 }
8385
8486 // / @brief An operator-like class of precondition function
8587 // / to encapsulate the pre-allocation of memory on different devices before starting the iterative eigensolver.
86- // / One can pass the operatr() function of this class, or other custom lambdas/functions to eigensolvers.
88+ // / One can use the `.get()` interface to get the function to be called by the eigensolver,
89+ // / or pass a custom lambda/function to replace the one returned by `.get()`.
8790 template <typename T, typename Device = base_device::DEVICE_CPU, typename Kernel_t = fvec::DivKernel<T, Device>>
8891 struct PreOP
8992 {
0 commit comments