Skip to content

Commit 3540ea2

Browse files
committed
Remove ctx in axpy_op
1 parent 17b75b9 commit 3540ea2

File tree

8 files changed

+14
-22
lines changed

8 files changed

+14
-22
lines changed

source/module_base/kernels/cuda/math_kernel_op.cu

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -665,8 +665,7 @@ void constantvector_addORsub_constantVector_op<T, base_device::DEVICE_GPU>::oper
665665
}
666666

667667
template <>
668-
void axpy_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
669-
const int& N,
668+
void axpy_op<double, base_device::DEVICE_GPU>::operator()(const int& N,
670669
const double* alpha,
671670
const double* X,
672671
const int& incX,
@@ -677,8 +676,7 @@ void axpy_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
677676
}
678677

679678
template <>
680-
void axpy_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
681-
const int& N,
679+
void axpy_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const int& N,
682680
const std::complex<float>* alpha,
683681
const std::complex<float>* X,
684682
const int& incX,
@@ -689,8 +687,7 @@ void axpy_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
689687
}
690688

691689
template <>
692-
void axpy_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
693-
const int& N,
690+
void axpy_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const int& N,
694691
const std::complex<double>* alpha,
695692
const std::complex<double>* X,
696693
const int& incX,

source/module_base/kernels/math_kernel_op.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,7 @@ struct gemv_op<T, base_device::DEVICE_CPU>
251251
template <typename T>
252252
struct axpy_op<T, base_device::DEVICE_CPU>
253253
{
254-
void operator()(const base_device::DEVICE_CPU* /*ctx*/,
255-
const int& dim,
254+
void operator()(const int& dim,
256255
const T* alpha,
257256
const T* X,
258257
const int& incX,

source/module_base/kernels/math_kernel_op.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,6 @@ template <typename T, typename Device> struct axpy_op {
190190
/// @brief Y = alpha * X + Y
191191
///
192192
/// Input Parameters
193-
/// \param d : the type of computing device
194193
/// \param N : array size
195194
/// \param alpha : input constant alpha
196195
/// \param X : input array X
@@ -200,7 +199,7 @@ template <typename T, typename Device> struct axpy_op {
200199
///
201200
/// Output Parameters
202201
/// \param Y : output array Y
203-
void operator()(const Device *d, const int &N, const T *alpha, const T *X,
202+
void operator()(const int &N, const T *alpha, const T *X,
204203
const int &incX, T *Y, const int &incY);
205204
};
206205

source/module_base/kernels/rocm/math_kernel_op.hip.cu

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -589,8 +589,7 @@ void constantvector_addORsub_constantVector_op<T, base_device::DEVICE_GPU>::oper
589589
}
590590

591591
template <>
592-
void axpy_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
593-
const int& N,
592+
void axpy_op<double, base_device::DEVICE_GPU>::operator()(const int& N,
594593
const double* alpha,
595594
const double* X,
596595
const int& incX,
@@ -601,8 +600,7 @@ void axpy_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
601600
}
602601

603602
template <>
604-
void axpy_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
605-
const int& N,
603+
void axpy_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const int& N,
606604
const std::complex<float>* alpha,
607605
const std::complex<float>* X,
608606
const int& incX,
@@ -613,8 +611,7 @@ void axpy_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
613611
}
614612

615613
template <>
616-
void axpy_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
617-
const int& N,
614+
void axpy_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const int& N,
618615
const std::complex<double>* alpha,
619616
const std::complex<double>* X,
620617
const int& incX,

source/module_base/kernels/test/math_kernel_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ TEST_F(TestModuleHsolverMathKernel, constantvector_addORsub_constantVector_op_cp
312312

313313
TEST_F(TestModuleHsolverMathKernel, axpy_op_cpu)
314314
{
315-
axpy_op_cpu()(cpu_ctx, dim, &alpha_axpy, X_axpy.data(), 1, Y_axpy.data(), 1);
315+
axpy_op_cpu()(dim, &alpha_axpy, X_axpy.data(), 1, Y_axpy.data(), 1);
316316
for (int i = 0; i < input.size(); i++)
317317
{
318318
EXPECT_LT(fabs(Y_axpy[i].imag() - output_axpy_op[i].imag()), 1e-8);
@@ -536,7 +536,7 @@ TEST_F(TestModuleHsolverMathKernel, axpy_op_gpu)
536536

537537
// run
538538
ModuleBase::createGpuBlasHandle();
539-
axpy_op_gpu()(gpu_ctx, dim, &alpha_axpy, X_axpy_dev, 1, Y_axpy_dev, 1);
539+
axpy_op_gpu()(dim, &alpha_axpy, X_axpy_dev, 1, Y_axpy_dev, 1);
540540
ModuleBase::destoryBLAShandle();
541541

542542
// sync the output data from GPU to CPU

source/module_hsolver/diago_cg.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ void DiagoCG<T, Device>::calc_gamma_cg(const int& iter,
386386
{
387387
pcg[i] -= norma * pphi_m[i];
388388
}*/
389-
ModuleBase::axpy_op<T, Device>()(ctx_, this->n_basis_, &znorma, phi_m.data<T>(), 1, cg.data<T>(), 1);
389+
ModuleBase::axpy_op<T, Device>()(this->n_basis_, &znorma, phi_m.data<T>(), 1, cg.data<T>(), 1);
390390
}
391391
}
392392

source/module_hsolver/kernels/test/perf_math_kernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_constantvector_addORsub_const
193193

194194
BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_axpy_op_cpu)(benchmark::State& state) {
195195
for (auto _ : state) {
196-
axpy_op_cpu()(cpu_ctx, dim_vector, &zconstant_a, test_zvector_a, 1 ,test_zvector_b, 1);
196+
axpy_op_cpu()(dim_vector, &zconstant_a, test_zvector_a, 1 ,test_zvector_b, 1);
197197
}
198198
}
199199

@@ -262,7 +262,7 @@ BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_constantvector_addORsub_const
262262

263263
BENCHMARK_DEFINE_F(PerfModuleHsolverMathKernel, BM_axpy_op_gpu)(benchmark::State& state) {
264264
for (auto _ : state) {
265-
axpy_op_gpu()(gpu_ctx, dim_vector, &zconstant_a, test_zvector_a_gpu, 1 ,test_zvector_b_gpu, 1);
265+
axpy_op_gpu()(dim_vector, &zconstant_a, test_zvector_a_gpu, 1 ,test_zvector_b_gpu, 1);
266266
}
267267
}
268268

source/module_hsolver/para_linear_transform.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ void PLinearTransform<T, Device>::act(const T alpha, const T* A, const T* U, con
141141
}
142142
// sum all the results
143143
T one = 1.0;
144-
ModuleBase::axpy_op<T, Device>()(ctx, ncolB * LDA, &one, B_tmp, 1, B, 1);
144+
ModuleBase::axpy_op<T, Device>()(ncolB * LDA, &one, B_tmp, 1, B, 1);
145145
}
146146
delmem_dev_op()(U_tmp);
147147
delmem_dev_op()(B_tmp);

0 commit comments

Comments
 (0)