Skip to content

Commit 8f7d319

Browse files
Perf: Optimize Davidson by fusing operators, offloading CPU computation to GPU, and reducing memory transfers (#6493)
* Perf: Optimize Diago_DavSubspace with GPU operators by adding and fusing custom kernels. Signed-off-by:Tianxiang Wang<[email protected]>, Contributed under MetaX Integrated Circuits (Shanghai) Co., Ltd. * Perf: reduce memory allocation and copy in Diago_DavSubspace::diag_zhegvx Signed-off-by:Tianxiang Wang<[email protected]>, Contributed under MetaX Integrated Circuits (Shanghai) Co., Ltd. * Perf: Replace loop-based 2D copy and memset with memcpy_2d_op, memset_2d_op Signed-off-by:Tianxiang Wang<[email protected]>, Contributed under MetaX Integrated Circuits (Shanghai) Co., Ltd. * Perf: use warp reduce instead of shared memory for better efficiency Signed-off-by:Tianxiang Wang<[email protected]>, Contributed under MetaX Integrated Circuits (Shanghai) Co., Ltd. * Fix compilation error Signed-off-by:Tianxiang Wang<[email protected]>, Contributed under MetaX Integrated Circuits (Shanghai) Co., Ltd.
1 parent ec3f0a6 commit 8f7d319

File tree

15 files changed

+768
-190
lines changed

15 files changed

+768
-190
lines changed

source/source_base/kernels/cuda/math_kernel_op.cu

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ __global__ void matrix_copy_kernel(const int n1, const int n2, const T* A, const
133133
}
134134
}
135135

136+
template <typename T, typename Real>
137+
__global__ void matrix_multiply_vector_kernel(const int m, const int n, T *a, const int lda, const Real *b, const Real alpha, T *c, const int ldc){
138+
int row = blockIdx.x * blockDim.x + threadIdx.x;
139+
int col = blockIdx.y * blockDim.y + threadIdx.y;
140+
if (col >= n || row >= m) return;
141+
c[col * ldc + row] = a[col * lda + row] * b[col] * alpha;
142+
}
143+
136144
cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
137145
{
138146
if (trans == 'N')
@@ -147,7 +155,7 @@ cublasOperation_t judge_trans_op(bool is_complex, const char& trans, const char*
147155
{
148156
return CUBLAS_OP_C;
149157
}
150-
else
158+
else
151159
{
152160
ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));
153161
}
@@ -438,10 +446,44 @@ void matrixCopy<std::complex<double>, base_device::DEVICE_GPU>::operator()(const
438446
cudaCheckOnDebug();
439447
}
440448

449+
template <>
450+
void matrix_mul_vector_op<double, base_device::DEVICE_GPU>::operator()(const int &m, const int &n,
451+
double *a, const int &lda, const double *b, const double alpha, double *c, const int &ldc){
452+
dim3 thread(16, 16, 1);
453+
dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
454+
matrix_multiply_vector_kernel<double, double> <<<block, thread >>>(m, n, a, lda,
455+
b, alpha, c, ldc);
456+
cudaCheckOnDebug();
457+
}
458+
459+
template <>
460+
void matrix_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const int &m, const int &n,
461+
std::complex<float> *a, const int &lda, const float *b, const float alpha, std::complex<float> *c, const int &ldc){
462+
dim3 thread(16, 16, 1);
463+
dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
464+
matrix_multiply_vector_kernel<thrust::complex<float>, float> <<<block, thread >>>(m, n, reinterpret_cast<thrust::complex<float>*>(a), lda,
465+
b, alpha, reinterpret_cast<thrust::complex<float>*>(c), ldc);
466+
cudaCheckOnDebug();
467+
}
468+
469+
template <>
470+
void matrix_mul_vector_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const int &m, const int &n,
471+
std::complex<double> *a, const int &lda, const double *b, const double alpha, std::complex<double> *c, const int &ldc)
472+
{
473+
dim3 thread(16, 16, 1);
474+
dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
475+
matrix_multiply_vector_kernel<thrust::complex<double>, double> <<<block, thread >>>(m, n, reinterpret_cast<thrust::complex<double>*>(a), lda,
476+
b, alpha, reinterpret_cast<thrust::complex<double>*>(c), ldc);
477+
cudaCheckOnDebug();
478+
}
441479

442480
// Explicitly instantiate functors for the types of functor registered.
443481

444482
template struct matrixCopy<std::complex<float>, base_device::DEVICE_GPU>;
445483
template struct matrixCopy<double, base_device::DEVICE_GPU>;
446484
template struct matrixCopy<std::complex<double>, base_device::DEVICE_GPU>;
485+
486+
template struct matrix_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>;
487+
template struct matrix_mul_vector_op<double, base_device::DEVICE_GPU>;
488+
template struct matrix_mul_vector_op<std::complex<double>, base_device::DEVICE_GPU>;
447489
} // namespace ModuleBase

source/source_base/kernels/math_kernel_op.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,35 @@ struct matrixCopy<T, base_device::DEVICE_CPU>
119119
}
120120
};
121121

122+
template <typename T>
123+
struct matrix_mul_vector_op<T, base_device::DEVICE_CPU> {
124+
using Real = typename GetTypeReal<T>::type;
125+
void operator()(const int& m, const int &n,
126+
T *a,
127+
const int &lda,
128+
const Real *b,
129+
const Real alpha,
130+
T *c,
131+
const int &ldc){
132+
#ifdef _OPENMP
133+
#pragma omp parallel for collapse(2) schedule(static, 8192 / sizeof(T))
134+
#endif
135+
for (int j = 0; j < n; j++){
136+
for (int i = 0; i < m; i++){
137+
c[j * ldc + i] = a[j * lda + i] * b[j] * alpha;
138+
}
139+
}
140+
141+
}
142+
};
143+
122144
template struct gemv_op<std::complex<float>, base_device::DEVICE_CPU>;
123145
template struct gemv_op<float, base_device::DEVICE_CPU>;
124146
template struct gemm_op<std::complex<float>, base_device::DEVICE_CPU>;
125147
template struct gemm_op<float, base_device::DEVICE_CPU>;
126148
template struct matrixTranspose_op<std::complex<float>, base_device::DEVICE_CPU>;
127149
template struct matrixCopy<std::complex<float>, base_device::DEVICE_CPU>;
150+
template struct matrix_mul_vector_op<std::complex<float>, base_device::DEVICE_CPU>;
128151

129152
template struct gemv_op<std::complex<double>, base_device::DEVICE_CPU>;
130153
template struct gemv_op<double, base_device::DEVICE_CPU>;
@@ -133,6 +156,8 @@ template struct gemm_op<double, base_device::DEVICE_CPU>;
133156
template struct matrixTranspose_op<std::complex<double>, base_device::DEVICE_CPU>;
134157
template struct matrixCopy<double, base_device::DEVICE_CPU>;
135158
template struct matrixCopy<std::complex<double>, base_device::DEVICE_CPU>;
159+
template struct matrix_mul_vector_op<double, base_device::DEVICE_CPU>;
160+
template struct matrix_mul_vector_op<std::complex<double>, base_device::DEVICE_CPU>;
136161

137162
#ifdef __LCAO
138163
template struct matrixTranspose_op<double, base_device::DEVICE_CPU>;

source/source_base/kernels/math_kernel_op.h

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ template <typename T, typename Device> struct vector_div_constant_op {
104104
///
105105
/// Input Parameters
106106
/// \param dim : array size
107-
/// \param vector : input array
107+
/// \param vector : input array
108108
/// \param constant : input constant
109109
///
110110
/// Output Parameters
@@ -298,6 +298,31 @@ template <typename T, typename Device> struct matrixCopy {
298298
void operator()(const int& n1, const int& n2, const T* A, const int& LDA, T* B, const int& LDB);
299299
};
300300

301+
template <typename T, typename Device>
302+
struct matrix_mul_vector_op {
303+
using Real = typename GetTypeReal<T>::type;
304+
/// @brief a * b * beta by each column
305+
///
306+
/// Input Parameters
307+
/// \param m : row number
308+
/// \param n : column number
309+
/// \param a : input matrix
310+
/// \param lda : leading dimension of matrix a
311+
/// \param b : input vector
312+
/// \param alpha : factor
313+
/// \param ldc : leading dimension of matrix c
314+
///
315+
/// Output Parameters
316+
/// \param c : output matrix
317+
void operator()(const int &m, const int &n,
318+
T *a,
319+
const int &lda,
320+
const Real *b,
321+
const Real alpha,
322+
T *c,
323+
const int &ldc);
324+
};
325+
301326
template <typename T, typename Device>
302327
struct apply_eigenvalues_op {
303328
using Real = typename GetTypeReal<T>::type;
@@ -314,7 +339,7 @@ struct precondition_op {
314339
T* psi_iter,
315340
const int& nbase,
316341
const int& notconv,
317-
const Real* precondition,
342+
const Real* precondition,
318343
const Real* eigenvalues);
319344
};
320345

@@ -393,6 +418,17 @@ template <typename T> struct matrixCopy<T, base_device::DEVICE_GPU> {
393418
const int& LDB);
394419
};
395420

421+
template <typename T> struct matrix_mul_vector_op<T, base_device::DEVICE_GPU> {
422+
using Real = typename GetTypeReal<T>::type;
423+
void operator()(const int &m, const int &n,
424+
T *a,
425+
const int &lda,
426+
const Real *b,
427+
const Real alpha,
428+
T *c,
429+
const int &ldc);
430+
};
431+
396432
void createGpuBlasHandle();
397433
void destoryBLAShandle();
398434

source/source_base/kernels/rocm/math_kernel_op.hip.cu

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,15 @@ __launch_bounds__(1024) __global__
145145
}
146146
}
147147

148+
template <typename T, typename Real>
149+
__launch_bounds__(1024) __global__
150+
void matrix_multiply_vector_kernel(const int m, const int n, T *a, const int lda, const Real *b, const Real alpha, T *c, const int ldc){
151+
int row = blockIdx.x * blockDim.x + threadIdx.x;
152+
int col = blockIdx.y * blockDim.y + threadIdx.y;
153+
if (col >= n || row >= m) return;
154+
c[col * ldc + row] = a[col * lda + row] * b[col] * alpha;
155+
}
156+
148157
hipblasOperation_t judge_trans_op(bool is_complex, const char& trans, const char* name)
149158
{
150159
if (trans == 'N')
@@ -159,7 +168,7 @@ hipblasOperation_t judge_trans_op(bool is_complex, const char& trans, const char
159168
{
160169
return HIPBLAS_OP_C;
161170
}
162-
else
171+
else
163172
{
164173
ModuleBase::WARNING_QUIT(name, std::string("Unknown trans type ") + trans + std::string(" !"));
165174
}
@@ -437,7 +446,38 @@ void matrixCopy<std::complex<double>, base_device::DEVICE_GPU>::operator()(const
437446
hipCheckOnDebug();
438447
}
439448

449+
template <>
450+
void matrix_mul_vector_op<double, base_device::DEVICE_GPU>::operator()(const int &m, const int &n,
451+
double *a, const int &lda, const double *b, const double alpha, double *c, const int &ldc){
452+
dim3 thread(16, 16, 1);
453+
dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
454+
hipLaunchKernelGGL(HIP_KERNEL_NAME(matrix_multiply_vector_kernel<double, double>), dim3(block, thread),
455+
m, n, a, lda, b, alpha, c, ldc);
456+
hipCheckOnDebug();
457+
}
440458

459+
template <>
460+
void matrix_mul_vector_op<std::complex<float>, base_device::DEVICE_GPU>::operator()(const int &m, const int &n,
461+
std::complex<float> *a, const int &lda, const float *b, const float alpha, std::complex<float> *c, const int &ldc){
462+
dim3 thread(16, 16, 1);
463+
dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
464+
hipLaunchKernelGGL(HIP_KERNEL_NAME(matrix_multiply_vector_kernel<thrust::complex<float>, float>), dim3(block, thread),
465+
m, n, reinterpret_cast<thrust::complex<float>*>(a), lda,
466+
b, alpha, reinterpret_cast<thrust::complex<float>*>(c), ldc);
467+
hipCheckOnDebug();
468+
}
469+
470+
template <>
471+
void matrix_mul_vector_op<std::complex<double>, base_device::DEVICE_GPU>::operator()(const int &m, const int &n,
472+
std::complex<double> *a, const int &lda, const double *b, const double alpha, std::complex<double> *c, const int &ldc)
473+
{
474+
dim3 thread(16, 16, 1);
475+
dim3 block((m + thread.x - 1) / thread.x, (n + thread.y - 1) / thread.y, 1);
476+
hipLaunchKernelGGL(HIP_KERNEL_NAME(matrix_multiply_vector_kernel<thrust::complex<double>, double>), dim3(block, thread),
477+
m, n, reinterpret_cast<thrust::complex<double>*>(a), lda,
478+
b, alpha, reinterpret_cast<thrust::complex<double>*>(c), ldc);
479+
hipCheckOnDebug();
480+
}
441481

442482
// Explicitly instantiate functors for the types of functor registered.
443483
template struct matrixCopy<double, base_device::DEVICE_GPU>;

source/source_base/module_device/cuda/memory_op.cu

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ void set_memory_op<FPTYPE, base_device::DEVICE_GPU>::operator()(FPTYPE* arr,
8585
cudaErrcheck(cudaMemset(arr, var, sizeof(FPTYPE) * size));
8686
}
8787

88+
template <typename FPTYPE>
89+
void set_memory_2d_op<FPTYPE, base_device::DEVICE_GPU>::operator()(FPTYPE* arr,
90+
const size_t pitch,
91+
const int var,
92+
const size_t width,
93+
const size_t height)
94+
{
95+
cudaErrcheck(cudaMemset2D(arr, sizeof(FPTYPE) * pitch , var, sizeof(FPTYPE) * width, height));
96+
}
97+
8898
template <typename FPTYPE>
8999
void synchronize_memory_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_GPU>::operator()(
90100
FPTYPE* arr_out,
@@ -112,6 +122,42 @@ void synchronize_memory_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_
112122
cudaErrcheck(cudaMemcpy(arr_out, arr_in, sizeof(FPTYPE) * size, cudaMemcpyDeviceToDevice));
113123
}
114124

125+
template <typename FPTYPE>
126+
void synchronize_memory_2d_op<FPTYPE, base_device::DEVICE_CPU, base_device::DEVICE_GPU>::operator()(
127+
FPTYPE* arr_out,
128+
const size_t dpitch,
129+
const FPTYPE* arr_in,
130+
const size_t spitch,
131+
const size_t width,
132+
const size_t height)
133+
{
134+
cudaErrcheck(cudaMemcpy2D(arr_out, dpitch * sizeof(FPTYPE), arr_in, spitch * sizeof(FPTYPE), width * sizeof(FPTYPE), height, cudaMemcpyDeviceToHost));
135+
}
136+
137+
template <typename FPTYPE>
138+
void synchronize_memory_2d_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_CPU>::operator()(
139+
FPTYPE* arr_out,
140+
const size_t dpitch,
141+
const FPTYPE* arr_in,
142+
const size_t spitch,
143+
const size_t width,
144+
const size_t height)
145+
{
146+
cudaErrcheck(cudaMemcpy2D(arr_out, dpitch * sizeof(FPTYPE), arr_in, spitch * sizeof(FPTYPE), width * sizeof(FPTYPE), height, cudaMemcpyHostToDevice));
147+
}
148+
149+
template <typename FPTYPE>
150+
void synchronize_memory_2d_op<FPTYPE, base_device::DEVICE_GPU, base_device::DEVICE_GPU>::operator()(
151+
FPTYPE* arr_out,
152+
const size_t dpitch,
153+
const FPTYPE* arr_in,
154+
const size_t spitch,
155+
const size_t width,
156+
const size_t height)
157+
{
158+
cudaErrcheck(cudaMemcpy2D(arr_out, dpitch * sizeof(FPTYPE), arr_in, spitch * sizeof(FPTYPE), width * sizeof(FPTYPE), height, cudaMemcpyDeviceToDevice));
159+
}
160+
115161
template <typename FPTYPE_out, typename FPTYPE_in>
116162
struct cast_memory_op<FPTYPE_out, FPTYPE_in, base_device::DEVICE_GPU, base_device::DEVICE_GPU>
117163
{
@@ -196,6 +242,12 @@ template struct set_memory_op<double, base_device::DEVICE_GPU>;
196242
template struct set_memory_op<std::complex<float>, base_device::DEVICE_GPU>;
197243
template struct set_memory_op<std::complex<double>, base_device::DEVICE_GPU>;
198244

245+
template struct set_memory_2d_op<int, base_device::DEVICE_GPU>;
246+
template struct set_memory_2d_op<float, base_device::DEVICE_GPU>;
247+
template struct set_memory_2d_op<double, base_device::DEVICE_GPU>;
248+
template struct set_memory_2d_op<std::complex<float>, base_device::DEVICE_GPU>;
249+
template struct set_memory_2d_op<std::complex<double>, base_device::DEVICE_GPU>;
250+
199251
template struct synchronize_memory_op<int, base_device::DEVICE_CPU, base_device::DEVICE_GPU>;
200252
template struct synchronize_memory_op<int, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
201253
template struct synchronize_memory_op<int, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
@@ -212,6 +264,22 @@ template struct synchronize_memory_op<std::complex<double>, base_device::DEVICE_
212264
template struct synchronize_memory_op<std::complex<double>, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
213265
template struct synchronize_memory_op<std::complex<double>, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
214266

267+
template struct synchronize_memory_2d_op<int, base_device::DEVICE_CPU, base_device::DEVICE_GPU>;
268+
template struct synchronize_memory_2d_op<int, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
269+
template struct synchronize_memory_2d_op<int, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
270+
template struct synchronize_memory_2d_op<float, base_device::DEVICE_CPU, base_device::DEVICE_GPU>;
271+
template struct synchronize_memory_2d_op<float, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
272+
template struct synchronize_memory_2d_op<float, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
273+
template struct synchronize_memory_2d_op<double, base_device::DEVICE_CPU, base_device::DEVICE_GPU>;
274+
template struct synchronize_memory_2d_op<double, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
275+
template struct synchronize_memory_2d_op<double, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
276+
template struct synchronize_memory_2d_op<std::complex<float>, base_device::DEVICE_CPU, base_device::DEVICE_GPU>;
277+
template struct synchronize_memory_2d_op<std::complex<float>, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
278+
template struct synchronize_memory_2d_op<std::complex<float>, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
279+
template struct synchronize_memory_2d_op<std::complex<double>, base_device::DEVICE_CPU, base_device::DEVICE_GPU>;
280+
template struct synchronize_memory_2d_op<std::complex<double>, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
281+
template struct synchronize_memory_2d_op<std::complex<double>, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
282+
215283
template struct cast_memory_op<float, float, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
216284
template struct cast_memory_op<double, double, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;
217285
template struct cast_memory_op<float, double, base_device::DEVICE_GPU, base_device::DEVICE_GPU>;

0 commit comments

Comments
 (0)