Skip to content

Commit 2ff323b

Browse files
committed
Remove ctx in scal_op
1 parent 9d98e3a commit 2ff323b

File tree

6 files changed

+10
-18
lines changed

6 files changed

+10
-18
lines changed

source/module_base/kernels/cuda/math_kernel_op.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -790,8 +790,7 @@ void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
790790
}
791791

792792
template <>
793-
void scal_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
794-
const int& N,
793+
void scal_op<float, base_device::DEVICE_GPU>::operator()(const int& N,
795794
const std::complex<float>* alpha,
796795
std::complex<float>* X,
797796
const int& incx)
@@ -800,8 +799,7 @@ void scal_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVI
800799
}
801800

802801
template <>
803-
void scal_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
804-
const int& N,
802+
void scal_op<double, base_device::DEVICE_GPU>::operator()(const int& N,
805803
const std::complex<double>* alpha,
806804
std::complex<double>* X,
807805
const int& incx)

source/module_base/kernels/math_kernel_op.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,7 @@ struct constantvector_addORsub_constantVector_op<T, base_device::DEVICE_CPU>
220220
template <typename FPTYPE>
221221
struct scal_op<FPTYPE, base_device::DEVICE_CPU>
222222
{
223-
void operator()(const base_device::DEVICE_CPU* /*ctx*/,
224-
const int& N,
223+
void operator()(const int& N,
225224
const std::complex<FPTYPE>* alpha,
226225
std::complex<FPTYPE>* X,
227226
const int& incx)

source/module_base/kernels/math_kernel_op.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,6 @@ template <typename T, typename Device> struct dot_real_op {
9090
/// while enabling planewave parallization strategy.
9191
///
9292
/// Input Parameters
93-
/// \param d : the type of computing device
9493
/// \param dim : array size
9594
/// \param psi_L : input array A
9695
/// \param psi_R : input array B
@@ -108,7 +107,6 @@ template <typename T, typename Device> struct vector_div_constant_op {
108107
/// @brief result[i] = vector[i] / constant
109108
///
110109
/// Input Parameters
111-
/// \param d : the type of computing device
112110
/// \param dim : array size
113111
/// \param vector : input array
114112
/// \param constant : input constant
@@ -124,15 +122,14 @@ template <typename FPTYPE, typename Device> struct scal_op {
124122
/// @brief x = alpha * x
125123
///
126124
/// Input Parameters
127-
/// \param d : the type of computing device
128125
/// \param N : array size
129126
/// \param alpha : input constant
130127
/// \param X : input array
131128
/// \param incx : computing strip of array X
132129
///
133130
/// Output Parameters
134131
/// \param X : output array
135-
void operator()(const Device *d, const int &N,
132+
void operator()(const int &N,
136133
const std::complex<FPTYPE> *alpha, std::complex<FPTYPE> *X,
137134
const int &incx);
138135
};

source/module_base/kernels/rocm/math_kernel_op.hip.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -708,8 +708,7 @@ void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const ba
708708
}
709709

710710
template <>
711-
void scal_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
712-
const int& N,
711+
void scal_op<float, base_device::DEVICE_GPU>::operator()(const int& N,
713712
const std::complex<float>* alpha,
714713
std::complex<float>* X,
715714
const int& incx)
@@ -718,8 +717,7 @@ void scal_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVI
718717
}
719718

720719
template <>
721-
void scal_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
722-
const int& N,
720+
void scal_op<double, base_device::DEVICE_GPU>::operator()(const int& N,
723721
const std::complex<double>* alpha,
724722
std::complex<double>* X,
725723
const int& incx)

source/module_base/kernels/test/math_kernel_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ TEST_F(TestModuleHsolverMathKernel, axpy_op_cpu)
323323

324324
TEST_F(TestModuleHsolverMathKernel, scal_op_cpu)
325325
{
326-
scal_op_cpu()(cpu_ctx, dim, &alpha_scal, X_scal.data(), 1);
326+
scal_op_cpu()(dim, &alpha_scal, X_scal.data(), 1);
327327
for (int i = 0; i < input.size(); i++)
328328
{
329329
EXPECT_LT(fabs(X_scal[i].imag() - output_scal_op[i].imag()), 1e-8);
@@ -567,7 +567,7 @@ TEST_F(TestModuleHsolverMathKernel, scal_op_gpu)
567567

568568
// run
569569
ModuleBase::createGpuBlasHandle();
570-
scal_op_gpu()(gpu_ctx, dim, &alpha_scal, X_scal_dev, 1);
570+
scal_op_gpu()(dim, &alpha_scal, X_scal_dev, 1);
571571
ModuleBase::destoryBLAShandle();
572572

573573
// syn the output data in GPU to CPU

source/module_hsolver/kernels/test/perf_math_kernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_axpy_op_cpu)(benchmark::State
199199

200200
BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_scal_op_cpu)(benchmark::State& state) {
201201
for (auto _ : state) {
202-
scal_op_cpu()(cpu_ctx, dim_vector, &zconstant_a, test_zvector_a, 1);
202+
scal_op_cpu()(dim_vector, &zconstant_a, test_zvector_a, 1);
203203
}
204204
}
205205

@@ -268,7 +268,7 @@ BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_axpy_op_gpu)(benchmark::State
268268

269269
BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_scal_op_gpu)(benchmark::State& state) {
270270
for (auto _ : state) {
271-
scal_op_gpu()(gpu_ctx, dim_vector, &zconstant_a, test_zvector_a_gpu, 1);
271+
scal_op_gpu()(dim_vector, &zconstant_a, test_zvector_a_gpu, 1);
272272
}
273273
}
274274

0 commit comments

Comments
 (0)