Skip to content

Commit 9db93de

Browse files
committed
Remove ctx in axpy_op
1 parent 3540ea2 commit 9db93de

File tree

13 files changed

+24
-46
lines changed

13 files changed

+24
-46
lines changed

source/module_base/kernels/cuda/math_kernel_op.cu

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -718,8 +718,7 @@ cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char*
718718
}
719719

720720
template <>
721-
void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
722-
const char& trans,
721+
void gemv_op<double, base_device::DEVICE_GPU>::operator()(const char& trans,
723722
const int& m,
724723
const int& n,
725724
const double* alpha,
@@ -736,8 +735,7 @@ void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
736735
}
737736

738737
template <>
739-
void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
740-
const char& trans,
738+
void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const char& trans,
741739
const int& m,
742740
const int& n,
743741
const std::complex<float>* alpha_in,
@@ -756,8 +754,7 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
756754
}
757755

758756
template <>
759-
void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
760-
const char& trans,
757+
void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const char& trans,
761758
const int& m,
762759
const int& n,
763760
const std::complex<double>* alpha_in,

source/module_base/kernels/math_kernel_op.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,7 @@ struct scal_op<FPTYPE, base_device::DEVICE_CPU>
231231
template <typename T>
232232
struct gemv_op<T, base_device::DEVICE_CPU>
233233
{
234-
void operator()(const base_device::DEVICE_CPU* d,
235-
const char& trans,
234+
void operator()(const char& trans,
236235
const int& m,
237236
const int& n,
238237
const T* alpha,

source/module_base/kernels/math_kernel_op.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ template <typename T, typename Device> struct gemv_op {
208208
/// @brief y = alpha * op(A) * x + beta * y
209209
///
210210
/// Input Parameters
211-
/// \param d : the type of computing device
212211
/// \param trans : whether to transpose A
213212
/// \param m : first dimension of matrix
214213
/// \param n : second dimension of matrix
@@ -223,7 +222,7 @@ template <typename T, typename Device> struct gemv_op {
223222
///
224223
/// Output Parameters
225224
/// \param Y : output array Y
226-
void operator()(const Device *d, const char &trans, const int &m,
225+
void operator()(const char &trans, const int &m,
227226
const int &n, const T *alpha, const T *A, const int &lda,
228227
const T *X, const int &incx, const T *beta, T *Y,
229228
const int &incy);

source/module_base/kernels/rocm/math_kernel_op.hip.cu

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -642,8 +642,7 @@ hipblasOperation_t judge_trans_op(bool is_complex, const char& trans, const char
642642
}
643643

644644
template <>
645-
void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
646-
const char& trans,
645+
void gemv_op<double, base_device::DEVICE_GPU>::operator()(const char& trans,
647646
const int& m,
648647
const int& n,
649648
const double* alpha,
@@ -660,8 +659,7 @@ void gemv_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
660659
}
661660

662661
template <>
663-
void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
664-
const char& trans,
662+
void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const char& trans,
665663
const int& m,
666664
const int& n,
667665
const std::complex<float>* alpha,
@@ -678,8 +676,7 @@ void gemv_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
678676
}
679677

680678
template <>
681-
void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
682-
const char& trans,
679+
void gemv_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const char& trans,
683680
const int& m,
684681
const int& n,
685682
const std::complex<double>* alpha,

source/module_base/kernels/test/math_kernel_test.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,7 @@ TEST_F(TestModuleHsolverMathKernel, scal_op_cpu)
332332

333333
TEST_F(TestModuleHsolverMathKernel, gemv_op_cpu)
334334
{
335-
gemv_op_cpu()(cpu_ctx,
336-
'C',
335+
gemv_op_cpu()('C',
337336
2,
338337
3,
339338
&ModuleBase::ONE,
@@ -598,7 +597,7 @@ TEST_F(TestModuleHsolverMathKernel, gemv_op_gpu)
598597

599598
// run
600599
ModuleBase::createGpuBlasHandle();
601-
gemv_op_gpu()(gpu_ctx, 'C', 2, 3, &ModuleBase::ONE, A_gemv_dev, 2, X_gemv_dev, 1, &ModuleBase::ONE, Y_gemv_dev, 1);
600+
gemv_op_gpu()('C', 2, 3, &ModuleBase::ONE, A_gemv_dev, 2, X_gemv_dev, 1, &ModuleBase::ONE, Y_gemv_dev, 1);
602601
ModuleBase::destoryBLAShandle();
603602
// syn the output data in GPU to CPU
604603
synchronize_memory_op_gpu()(Y_gemv.data(), Y_gemv_dev, Y_gemv.size());

source/module_elecstate/elecstate_pw.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,7 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
302302
if (nbands == 1)
303303
{
304304
int inc = 1;
305-
gemv_op()(this->ctx,
306-
transa,
305+
gemv_op()(transa,
307306
npw,
308307
this->ppcell->nkb,
309308
&one,

source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,7 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
260260
if (nbands == 1)
261261
{
262262
int inc = 1;
263-
gemv_op()(this->ctx,
264-
transa,
263+
gemv_op()(transa,
265264
npw,
266265
this->ppcell->nkb,
267266
&one,
@@ -351,8 +350,7 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
351350
if (nbands == 1)
352351
{
353352
const int inc = 1;
354-
gemv_op()(this->ctx,
355-
transa,
353+
gemv_op()(transa,
356354
npw,
357355
this->ppcell->nkb,
358356
&one,

source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ void Nonlocal<OperatorPW<T, Device>>::add_nonlocal_pp(T *hpsi_in, const T *becp,
170170
// denghui replace 2022-10-20
171171
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
172172
gemv_op()(
173-
this->ctx,
174173
transa,
175174
this->npw,
176175
this->ppcell->nkb,
@@ -246,7 +245,6 @@ void Nonlocal<OperatorPW<T, Device>>::act(
246245
// denghui replace 2022-10-20
247246
// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
248247
gemv_op()(
249-
this->ctx,
250248
transa,
251249
this->npw,
252250
nkb,

source/module_hamilt_pw/hamilt_stodft/sto_che.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ REAL vTMv(const REAL* v, const REAL* M, const int n)
5151
const REAL zero = 0;
5252
REAL* y = nullptr;
5353
base_device::memory::resize_memory_op<REAL, Device>()(y, n);
54-
ModuleBase::gemv_op<REAL, Device>()(ctx, normal, n, n, &one, M, n, v, inc, &zero, y, inc);
54+
ModuleBase::gemv_op<REAL, Device>()(normal, n, n, &one, M, n, v, inc, &zero, y, inc);
5555
REAL result = 0;
5656
REAL* dot_device = nullptr;
5757
base_device::memory::resize_memory_op<REAL, Device>()(dot_device, 1);

source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ void Stochastic_Iter<T, Device>::calTnchi_ik(const int& ik, Stochastic_WF<T, Dev
769769
T* coef_real = nullptr;
770770
resmem_complex_op()(coef_real, N);
771771
castmem_d2z_op()(coef_real, p_che->coef_real, p_che->norder);
772-
gemv_op()(this->ctx, transa, M, N, &one, stowf.chiallorder[ik].get_pointer(), LDA, coef_real, inc, &zero, out, inc);
772+
gemv_op()(transa, M, N, &one, stowf.chiallorder[ik].get_pointer(), LDA, coef_real, inc, &zero, out, inc);
773773
// zgemv_(&transa, &M, &N, &one, stowf.chiallorder[ik].get_pointer(), &LDA, coef_real, &inc, &zero, out, &inc);
774774
delmem_complex_op()(coef_real);
775775
}

0 commit comments

Comments
 (0)