Skip to content

Commit 2c47138

Browse files
authored
[Refactor] Remove all ctx in math_kernel_op (#5934)
* Remove ctx in dot_real_op * Fix dot_complex_wrapper bug * Remove ctx in vector_div_constant_op * Remove ctx in scal_op * Remove ctx in vector_mul_vector_op * Remove ctx in vector_div_vector_op * Fix compile bug * Remove ctx in constantVector_addOrsub_constantVector * Remove ctx in axpy_op * Remove ctx in axpy_op * Remove ctx in gemm_op * Remove ctx in matrixTranspose_op * Remove ctx in matrixSetToAnother * Fix bot compilation error 1 * Fix bot compilation error 2
1 parent 8fa625c commit 2c47138

28 files changed

+254
-428
lines changed

source/module_base/blas_connector.cpp

Lines changed: 6 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -782,7 +782,7 @@ void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vec
782782
}
783783
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
784784
#ifdef __CUDA
785-
ModuleBase::vector_mul_vector_op<T, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, vector2);
785+
ModuleBase::vector_mul_vector_op<T, base_device::DEVICE_GPU>()(dim, result, vector1, vector2);
786786
#endif
787787
}
788788
}
@@ -802,7 +802,7 @@ void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vec
802802
}
803803
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
804804
#ifdef __CUDA
805-
ModuleBase::vector_div_vector_op<T, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, vector2);
805+
ModuleBase::vector_div_vector_op<T, base_device::DEVICE_GPU>()(dim, result, vector1, vector2);
806806
#endif
807807
}
808808
}
@@ -820,7 +820,7 @@ void vector_add_vector(const int& dim, float *result, const float *vector1, cons
820820
}
821821
else if (device_type == base_device::GpuDevice){
822822
#ifdef __CUDA
823-
ModuleBase::constantvector_addORsub_constantVector_op<float, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
823+
ModuleBase::constantvector_addORsub_constantVector_op<float, base_device::DEVICE_GPU>()(dim, result, vector1, constant1, vector2, constant2);
824824
#endif
825825
}
826826
}
@@ -838,7 +838,7 @@ void vector_add_vector(const int& dim, double *result, const double *vector1, co
838838
}
839839
else if (device_type == base_device::GpuDevice){
840840
#ifdef __CUDA
841-
ModuleBase::constantvector_addORsub_constantVector_op<double, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
841+
ModuleBase::constantvector_addORsub_constantVector_op<double, base_device::DEVICE_GPU>()(dim, result, vector1, constant1, vector2, constant2);
842842
#endif
843843
}
844844
}
@@ -856,7 +856,7 @@ void vector_add_vector(const int& dim, std::complex<float> *result, const std::c
856856
}
857857
else if (device_type == base_device::GpuDevice){
858858
#ifdef __CUDA
859-
ModuleBase::constantvector_addORsub_constantVector_op<std::complex<float>, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
859+
ModuleBase::constantvector_addORsub_constantVector_op<std::complex<float>, base_device::DEVICE_GPU>()(dim, result, vector1, constant1, vector2, constant2);
860860
#endif
861861
}
862862
}
@@ -874,7 +874,7 @@ void vector_add_vector(const int& dim, std::complex<double> *result, const std::
874874
}
875875
else if (device_type == base_device::GpuDevice){
876876
#ifdef __CUDA
877-
ModuleBase::constantvector_addORsub_constantVector_op<std::complex<double>, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
877+
ModuleBase::constantvector_addORsub_constantVector_op<std::complex<double>, base_device::DEVICE_GPU>()(dim, result, vector1, constant1, vector2, constant2);
878878
#endif
879879
}
880880
}

source/module_base/kernels/cuda/math_kernel_op.cu

Lines changed: 38 additions & 73 deletions
Large diffs are not rendered by default.

source/module_base/kernels/math_kernel_op.cpp

Lines changed: 13 additions & 22 deletions
Original file line number · Diff line number · Diff line change
@@ -110,8 +110,7 @@ struct calc_grad_with_block_op<T, base_device::DEVICE_CPU>
110110
template <typename FPTYPE>
111111
struct dot_real_op<FPTYPE, base_device::DEVICE_CPU>
112112
{
113-
FPTYPE operator()(const base_device::DEVICE_CPU* d,
114-
const int& dim,
113+
FPTYPE operator()(const int& dim,
115114
const FPTYPE* psi_L,
116115
const FPTYPE* psi_R,
117116
const bool reduce)
@@ -129,8 +128,7 @@ struct dot_real_op<FPTYPE, base_device::DEVICE_CPU>
129128
template <typename FPTYPE>
130129
struct dot_real_op<std::complex<FPTYPE>, base_device::DEVICE_CPU>
131130
{
132-
FPTYPE operator()(const base_device::DEVICE_CPU* d,
133-
const int& dim,
131+
FPTYPE operator()(const int& dim,
134132
const std::complex<FPTYPE>* psi_L,
135133
const std::complex<FPTYPE>* psi_R,
136134
const bool reduce)
@@ -153,7 +151,7 @@ template <typename T>
153151
struct vector_div_constant_op<T, base_device::DEVICE_CPU>
154152
{
155153
using Real = typename GetTypeReal<T>::type;
156-
void operator()(const base_device::DEVICE_CPU* d, const int dim, T* result, const T* vector, const Real constant)
154+
void operator()(const int dim, T* result, const T* vector, const Real constant)
157155
{
158156
#ifdef _OPENMP
159157
#pragma omp parallel for schedule(static, 4096 / sizeof(Real))
@@ -169,7 +167,7 @@ template <typename T>
169167
struct vector_mul_vector_op<T, base_device::DEVICE_CPU>
170168
{
171169
using Real = typename GetTypeReal<T>::type;
172-
void operator()(const base_device::DEVICE_CPU* d, const int& dim, T* result, const T* vector1, const Real* vector2)
170+
void operator()(const int& dim, T* result, const T* vector1, const Real* vector2)
173171
{
174172
#ifdef _OPENMP
175173
#pragma omp parallel for schedule(static, 4096 / sizeof(Real))
@@ -185,7 +183,7 @@ template <typename T>
185183
struct vector_div_vector_op<T, base_device::DEVICE_CPU>
186184
{
187185
using Real = typename GetTypeReal<T>::type;
188-
void operator()(const base_device::DEVICE_CPU* d, const int& dim, T* result, const T* vector1, const Real* vector2)
186+
void operator()(const int& dim, T* result, const T* vector1, const Real* vector2)
189187
{
190188
#ifdef _OPENMP
191189
#pragma omp parallel for schedule(static, 4096 / sizeof(Real))
@@ -201,8 +199,7 @@ template <typename T>
201199
struct constantvector_addORsub_constantVector_op<T, base_device::DEVICE_CPU>
202200
{
203201
using Real = typename GetTypeReal<T>::type;
204-
void operator()(const base_device::DEVICE_CPU* d,
205-
const int& dim,
202+
void operator()(const int& dim,
206203
T* result,
207204
const T* vector1,
208205
const Real constant1,
@@ -222,8 +219,7 @@ struct constantvector_addORsub_constantVector_op<T, base_device::DEVICE_CPU>
222219
template <typename FPTYPE>
223220
struct scal_op<FPTYPE, base_device::DEVICE_CPU>
224221
{
225-
void operator()(const base_device::DEVICE_CPU* /*ctx*/,
226-
const int& N,
222+
void operator()(const int& N,
227223
const std::complex<FPTYPE>* alpha,
228224
std::complex<FPTYPE>* X,
229225
const int& incx)
@@ -235,8 +231,7 @@ struct scal_op<FPTYPE, base_device::DEVICE_CPU>
235231
template <typename T>
236232
struct gemv_op<T, base_device::DEVICE_CPU>
237233
{
238-
void operator()(const base_device::DEVICE_CPU* d,
239-
const char& trans,
234+
void operator()(const char& trans,
240235
const int& m,
241236
const int& n,
242237
const T* alpha,
@@ -255,8 +250,7 @@ struct gemv_op<T, base_device::DEVICE_CPU>
255250
template <typename T>
256251
struct axpy_op<T, base_device::DEVICE_CPU>
257252
{
258-
void operator()(const base_device::DEVICE_CPU* /*ctx*/,
259-
const int& dim,
253+
void operator()(const int& dim,
260254
const T* alpha,
261255
const T* X,
262256
const int& incX,
@@ -270,8 +264,7 @@ struct axpy_op<T, base_device::DEVICE_CPU>
270264
template <typename T>
271265
struct gemm_op<T, base_device::DEVICE_CPU>
272266
{
273-
void operator()(const base_device::DEVICE_CPU* /*ctx*/,
274-
const char& transa,
267+
void operator()(const char& transa,
275268
const char& transb,
276269
const int& m,
277270
const int& n,
@@ -293,8 +286,7 @@ struct gemm_op<T, base_device::DEVICE_CPU>
293286
template <typename T>
294287
struct gemm_op_mt<T, base_device::DEVICE_CPU>
295288
{
296-
void operator()(const base_device::DEVICE_CPU* /*ctx*/,
297-
const char& transa,
289+
void operator()(const char& transa,
298290
const char& transb,
299291
const int& m,
300292
const int& n,
@@ -316,8 +308,7 @@ struct gemm_op_mt<T, base_device::DEVICE_CPU>
316308
template <typename T>
317309
struct matrixTranspose_op<T, base_device::DEVICE_CPU>
318310
{
319-
void operator()(const base_device::DEVICE_CPU* d,
320-
const int& row,
311+
void operator()(const int& row,
321312
const int& col,
322313
const T* input_matrix,
323314
T* output_matrix)
@@ -348,7 +339,7 @@ struct matrixTranspose_op<T, base_device::DEVICE_CPU>
348339
template <typename T>
349340
struct matrixSetToAnother<T, base_device::DEVICE_CPU>
350341
{
351-
void operator()(const base_device::DEVICE_CPU* d, const int& n, const T* A, const int& LDA, T* B, const int& LDB)
342+
void operator()(const int& n, const T* A, const int& LDA, T* B, const int& LDB)
352343
{
353344
#ifdef _OPENMP
354345
#pragma omp parallel for collapse(2) schedule(static, 8192 / sizeof(T))

0 commit comments

Comments (0)