Skip to content

Commit 52fc451

Browse files
committed
Remove ctx in gemm_op
1 parent 9db93de commit 52fc451

File tree

20 files changed

+63
-126
lines changed

20 files changed

+63
-126
lines changed

source/module_base/kernels/cuda/math_kernel_op.cu

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -793,8 +793,7 @@ void scal_op<double, base_device::DEVICE_GPU>::operator()(const int& N,
793793
}
794794

795795
template <>
796-
void gemm_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
797-
const char& transa,
796+
void gemm_op<float, base_device::DEVICE_GPU>::operator()(const char& transa,
798797
const char& transb,
799798
const int& m,
800799
const int& n,
@@ -814,8 +813,7 @@ void gemm_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVI
814813
}
815814

816815
template <>
817-
void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
818-
const char& transa,
816+
void gemm_op<double, base_device::DEVICE_GPU>::operator()(const char& transa,
819817
const char& transb,
820818
const int& m,
821819
const int& n,
@@ -834,8 +832,7 @@ void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
834832
cublasErrcheck(cublasDgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
835833
}
836834
template <>
837-
void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
838-
const char& transa,
835+
void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const char& transa,
839836
const char& transb,
840837
const int& m,
841838
const int& n,
@@ -855,8 +852,7 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
855852
}
856853

857854
template <>
858-
void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
859-
const char& transa,
855+
void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const char& transa,
860856
const char& transb,
861857
const int& m,
862858
const int& n,

source/module_base/kernels/math_kernel_op.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,7 @@ struct axpy_op<T, base_device::DEVICE_CPU>
264264
template <typename T>
265265
struct gemm_op<T, base_device::DEVICE_CPU>
266266
{
267-
void operator()(const base_device::DEVICE_CPU* /*ctx*/,
268-
const char& transa,
267+
void operator()(const char& transa,
269268
const char& transb,
270269
const int& m,
271270
const int& n,
@@ -287,8 +286,7 @@ struct gemm_op<T, base_device::DEVICE_CPU>
287286
template <typename T>
288287
struct gemm_op_mt<T, base_device::DEVICE_CPU>
289288
{
290-
void operator()(const base_device::DEVICE_CPU* /*ctx*/,
291-
const char& transa,
289+
void operator()(const char& transa,
292290
const char& transb,
293291
const int& m,
294292
const int& n,

source/module_base/kernels/math_kernel_op.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ template <typename T, typename Device> struct gemm_op {
233233
/// @brief C = alpha * op(A) * op(B) + beta * C
234234
///
235235
/// Input Parameters
236-
/// \param d : the type of computing device
237236
/// \param transa : whether to transpose matrix A
238237
/// \param transb : whether to transpose matrix B
239238
/// \param m : first dimension of matrix mulplication
@@ -250,7 +249,7 @@ template <typename T, typename Device> struct gemm_op {
250249
///
251250
/// Output Parameters
252251
/// \param c : output matrix C
253-
void operator()(const Device *d, const char &transa, const char &transb,
252+
void operator()(const char &transa, const char &transb,
254253
const int &m, const int &n, const int &k, const T *alpha,
255254
const T *a, const int &lda, const T *b, const int &ldb,
256255
const T *beta, T *c, const int &ldc);
@@ -262,7 +261,6 @@ template <typename T, typename Device> struct gemm_op_mt {
262261
/// @brief C = alpha * op(A) * op(B) + beta * C
263262
///
264263
/// Input Parameters
265-
/// \param d : the type of computing device
266264
/// \param transa : whether to transpose matrix A
267265
/// \param transb : whether to transpose matrix B
268266
/// \param m : first dimension of matrix mulplication
@@ -279,7 +277,7 @@ template <typename T, typename Device> struct gemm_op_mt {
279277
///
280278
/// Output Parameters
281279
/// \param c : output matrix C
282-
void operator()(const Device *d, const char &transa, const char &transb,
280+
void operator()(const char &transa, const char &transb,
283281
const int &m, const int &n, const int &k, const T *alpha,
284282
const T *a, const int &lda, const T *b, const int &ldb,
285283
const T *beta, T *c, const int &ldc);

source/module_base/kernels/rocm/math_kernel_op.hip.cu

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -711,8 +711,7 @@ void scal_op<double, base_device::DEVICE_GPU>::operator()(const int& N,
711711
}
712712

713713
template <>
714-
void gemm_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
715-
const char& transa,
714+
void gemm_op<float, base_device::DEVICE_GPU>::operator()(const char& transa,
716715
const char& transb,
717716
const int& m,
718717
const int& n,
@@ -732,8 +731,7 @@ void gemm_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVI
732731
}
733732

734733
template <>
735-
void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
736-
const char& transa,
734+
void gemm_op<double, base_device::DEVICE_GPU>::operator()(const char& transa,
737735
const char& transb,
738736
const int& m,
739737
const int& n,
@@ -753,8 +751,7 @@ void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
753751
}
754752

755753
template <>
756-
void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
757-
const char& transa,
754+
void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const char& transa,
758755
const char& transb,
759756
const int& m,
760757
const int& n,
@@ -774,8 +771,7 @@ void gemm_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const bas
774771
}
775772

776773
template <>
777-
void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
778-
const char& transa,
774+
void gemm_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const char& transa,
779775
const char& transb,
780776
const int& m,
781777
const int& n,

source/module_base/para_gemm.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ void PGemmCN<T, Device>::multiply_single(const T alpha, const T* A, const T* B,
137137
#else
138138
T real_beta = beta;
139139
#endif
140-
ModuleBase::gemm_op<T, Device>()(ctx, 'C', 'N', ncolA, ncolB, nrow, &alpha, A, LDA, B, LDB, &real_beta, C, LDC);
140+
ModuleBase::gemm_op<T, Device>()('C', 'N', ncolA, ncolB, nrow, &alpha, A, LDA, B, LDB, &real_beta, C, LDC);
141141
#ifdef __MPI
142142
if (this->row_nproc > 1)
143143
{
@@ -201,8 +201,7 @@ void PGemmCN<T, Device>::multiply_col(const T alpha, const T* A, const T* B, con
201201
T* C_start = C_local + shift;
202202
if (col_rank == ip)
203203
{
204-
ModuleBase::gemm_op<T, Device>()(ctx,
205-
'C',
204+
ModuleBase::gemm_op<T, Device>()('C',
206205
'N',
207206
ncolA,
208207
ncolB,
@@ -224,8 +223,7 @@ void PGemmCN<T, Device>::multiply_col(const T alpha, const T* A, const T* B, con
224223
MPI_Status status;
225224
Parallel_Common::recv_dev<T, Device>(Atmp_device, size, ip, 0, col_world, &status, B_tmp.data());
226225
MPI_Wait(&requests[ip], &status);
227-
ModuleBase::gemm_op<T, Device>()(ctx,
228-
'C',
226+
ModuleBase::gemm_op<T, Device>()('C',
229227
'N',
230228
m,
231229
ncolB,
@@ -321,8 +319,7 @@ void PGemmCN<T, Device>::multiply_row(const T alpha, const T* A, const T* B, con
321319
T* C_start = C + shift;
322320
if (col_rank == ip)
323321
{
324-
ModuleBase::gemm_op<T, Device>()(ctx,
325-
'C',
322+
ModuleBase::gemm_op<T, Device>()('C',
326323
'N',
327324
ncolA,
328325
ncolB,
@@ -344,8 +341,7 @@ void PGemmCN<T, Device>::multiply_row(const T alpha, const T* A, const T* B, con
344341
MPI_Status status;
345342
Parallel_Common::recv_dev<T, Device>(Btmp_device, size, ip, 0, col_world, &status, B_tmp.data());
346343
MPI_Wait(&requests[ip], &status);
347-
ModuleBase::gemm_op<T, Device>()(ctx,
348-
'C',
344+
ModuleBase::gemm_op<T, Device>()('C',
349345
'N',
350346
ncolA,
351347
m,

source/module_base/test_parallel/test_para_gemm.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,7 @@ class PgemmTest : public ::testing::Test
141141
const base_device::DEVICE_CPU* ctx = {};
142142
char transC = 'C';
143143
char transN = 'N';
144-
ModuleBase::gemm_op<T, base_device::DEVICE_CPU>()(ctx,
145-
transC,
144+
ModuleBase::gemm_op<T, base_device::DEVICE_CPU>()(transC,
146145
transN,
147146
ncolA_global,
148147
ncolB_global,

source/module_elecstate/elecstate_pw.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -316,8 +316,7 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
316316
}
317317
else
318318
{
319-
gemm_op()(this->ctx,
320-
transa,
319+
gemm_op()(transa,
321320
transb,
322321
this->ppcell->nkb,
323322
nbands,
@@ -367,8 +366,7 @@ void ElecStatePW<T, Device>::cal_becsum(const psi::Psi<T, Device>& psi)
367366

368367
char transa = 'C';
369368
char transb = 'N';
370-
gemm_op()(this->ctx,
371-
transa,
369+
gemm_op()(transa,
372370
transb,
373371
atom->ncpp.nh,
374372
atom->ncpp.nh,
@@ -517,8 +515,7 @@ void ElecStatePW<T, Device>::addusdens_g(const Real* becsum, T** rhog)
517515
// sum over atoms
518516
char transa = 'N';
519517
char transb = 'T';
520-
gemm_op()(this->ctx,
521-
transa,
518+
gemm_op()(transa,
522519
transb,
523520
npw,
524521
nij,

source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ void spinconstrain::SpinConstrain<std::complex<double>>::calculate_delta_hcc(std
8585
#if ((defined __CUDA) || (defined __ROCM))
8686
base_device::DEVICE_GPU* ctx = {};
8787
ModuleBase::gemm_op<std::complex<double>, base_device::DEVICE_GPU>()(
88-
ctx,
8988
transa,
9089
transb,
9190
nbands,
@@ -109,7 +108,6 @@ void spinconstrain::SpinConstrain<std::complex<double>>::calculate_delta_hcc(std
109108
{
110109
base_device::DEVICE_CPU* ctx = {};
111110
ModuleBase::gemm_op<std::complex<double>, base_device::DEVICE_CPU>()(
112-
ctx,
113111
transa,
114112
transb,
115113
nbands,

source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_becp(const int& ik,
269269
const char transb = 'N';
270270
const int npm_npol = npm * npol;
271271
const int index0 = nbd0 * npol * nkb;
272-
gemm_op()(this->ctx,
273-
transa,
272+
gemm_op()(transa,
274273
transb,
275274
this->nkb,
276275
npm_npol,
@@ -433,8 +432,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_s(const int& ik,
433432
// 2.b calculate dbecp = dbecp_noevc * psi
434433
const char transa = 'C';
435434
const char transb = 'N';
436-
gemm_op()(this->ctx,
437-
transa,
435+
gemm_op()(transa,
438436
transb,
439437
this->nkb,
440438
npm_npol,
@@ -587,8 +585,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::cal_dbecp_f(const int& ik,
587585
// do gemm to get dbecp and revert the ppcell_vkb for next ipol
588586
const char transa = 'C';
589587
const char transb = 'N';
590-
gemm_op()(this->ctx,
591-
transa,
588+
gemm_op()(transa,
592589
transb,
593590
this->nkb,
594591
npm_npol,

source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,8 +274,7 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
274274
}
275275
else
276276
{
277-
gemm_op()(this->ctx,
278-
transa,
277+
gemm_op()(transa,
279278
transb,
280279
this->ppcell->nkb,
281280
nbands,
@@ -328,8 +327,7 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
328327
for (int ia = 0; ia < atoms->na; ia++)
329328
{
330329
const int iat = ucell->itia2iat(it, ia);
331-
gemm_op()(this->ctx,
332-
transa,
330+
gemm_op()(transa,
333331
transb,
334332
nh,
335333
nbands,
@@ -364,8 +362,7 @@ void HamiltPW<T, Device>::sPsi(const T* psi_in, // psi
364362
}
365363
else
366364
{
367-
gemm_op()(this->ctx,
368-
transa,
365+
gemm_op()(transa,
369366
transb,
370367
npw,
371368
nbands,

0 commit comments

Comments
 (0)