Skip to content

Commit 57cecae

Browse files
committed
fix compile
1 parent 0fde07e commit 57cecae

File tree

9 files changed

+78
-32
lines changed

9 files changed

+78
-32
lines changed

source/module_base/kernels/cuda/math_kernel_op.cu

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#include "module_base/module_device/memory_op.h"
2-
#include "module_ModuleBase/kernels/math_kernel_op.h"
2+
#include "module_base/kernels/math_kernel_op.h"
33
#include "module_psi/psi.h"
44
#include "module_base/tool_quit.h"
55

@@ -817,6 +817,27 @@ void scal_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
817817
cublasErrcheck(cublasZscal(cublas_handle, N, (double2*)alpha, (double2*)X, incx));
818818
}
819819

820+
template <>
821+
void gemm_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
822+
const char& transa,
823+
const char& transb,
824+
const int& m,
825+
const int& n,
826+
const int& k,
827+
const float* alpha,
828+
const float* a,
829+
const int& lda,
830+
const float* b,
831+
const int& ldb,
832+
const float* beta,
833+
float* c,
834+
const int& ldc)
835+
{
836+
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
837+
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
838+
cublasErrcheck(cublasSgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
839+
}
840+
820841
template <>
821842
void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
822843
const char& transa,

source/module_base/kernels/rocm/math_kernel_op.hip.cu

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
#include "module_base/module_device/memory_op.h"
2-
#include "module_ModuleBase/kernels/math_kernel_op.h"
2+
#include "module_base/kernels/math_kernel_op.h"
33
#include "module_psi/psi.h"
44
#include "module_base/tool_quit.h"
55

@@ -735,6 +735,27 @@ void scal_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
735735
hipblasErrcheck(hipblasZscal(cublas_handle, N, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)X, incx));
736736
}
737737

738+
template <>
739+
void gemm_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
740+
const char& transa,
741+
const char& transb,
742+
const int& m,
743+
const int& n,
744+
const int& k,
745+
const float* alpha,
746+
const float* a,
747+
const int& lda,
748+
const float* b,
749+
const int& ldb,
750+
const float* beta,
751+
float* c,
752+
const int& ldc)
753+
{
754+
hipblasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
755+
hipblasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
756+
hipblasErrcheck(hipblasSgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
757+
}
758+
738759
template <>
739760
void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
740761
const char& transa,

source/module_base/kernels/test/math_kernel_test.cpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -375,9 +375,9 @@ TEST_F(TestModuleHsolverMathKernel, zdot_real_op_gpu)
375375
resize_memory_op()(psi_R_dev, psi_R.size());
376376
synchronize_memory_op()(psi_L_dev, psi_L.data(), psi_L.size());
377377
synchronize_memory_op()(psi_R_dev, psi_R.data(), psi_R.size());
378-
hsolver::createGpuBlasHandle();
378+
ModuleBase::createGpuBlasHandle();
379379
double result = zdot_real_gpu_op()(gpu_ctx, dim, psi_L_dev, psi_R_dev, false);
380-
hsolver::destoryBLAShandle();
380+
ModuleBase::destoryBLAShandle();
381381
EXPECT_LT(fabs(result - expected_result), 1e-12);
382382
delete_memory_op()(psi_L_dev);
383383
delete_memory_op()(psi_R_dev);
@@ -537,9 +537,9 @@ TEST_F(TestModuleHsolverMathKernel, axpy_op_gpu)
537537
synchronize_memory_op()(Y_axpy_dev, Y_axpy.data(), Y_axpy.size());
538538

539539
// run
540-
hsolver::createGpuBlasHandle();
540+
ModuleBase::createGpuBlasHandle();
541541
axpy_op_gpu()(gpu_ctx, dim, &alpha_axpy, X_axpy_dev, 1, Y_axpy_dev, 1);
542-
hsolver::destoryBLAShandle();
542+
ModuleBase::destoryBLAShandle();
543543

544544
// syn the output data in GPU to CPU
545545
synchronize_memory_op_gpu()(Y_axpy.data(), Y_axpy_dev, Y_axpy.size());
@@ -566,9 +566,9 @@ TEST_F(TestModuleHsolverMathKernel, scal_op_gpu)
566566
synchronize_memory_op()(X_scal_dev, X_scal.data(), X_scal.size());
567567

568568
// run
569-
hsolver::createGpuBlasHandle();
569+
ModuleBase::createGpuBlasHandle();
570570
scal_op_gpu()(gpu_ctx, dim, &alpha_scal, X_scal_dev, 1);
571-
hsolver::destoryBLAShandle();
571+
ModuleBase::destoryBLAShandle();
572572

573573
// syn the output data in GPU to CPU
574574
synchronize_memory_op_gpu()(X_scal.data(), X_scal_dev, X_scal.size());
@@ -599,9 +599,9 @@ TEST_F(TestModuleHsolverMathKernel, gemv_op_gpu)
599599
synchronize_memory_op()(Y_gemv_dev, Y_gemv.data(), Y_gemv.size());
600600

601601
// run
602-
hsolver::createGpuBlasHandle();
602+
ModuleBase::createGpuBlasHandle();
603603
gemv_op_gpu()(gpu_ctx, 'C', 2, 3, &ModuleBase::ONE, A_gemv_dev, 2, X_gemv_dev, 1, &ModuleBase::ONE, Y_gemv_dev, 1);
604-
hsolver::destoryBLAShandle();
604+
ModuleBase::destoryBLAShandle();
605605
// syn the output data in GPU to CPU
606606
synchronize_memory_op_gpu()(Y_gemv.data(), Y_gemv_dev, Y_gemv.size());
607607

source/module_base/para_gemm.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -189,9 +189,11 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
189189
}
190190
}
191191
else
192-
#endif
193192
{
194193
T real_beta = row_rank == 0 ? beta : 0;
194+
#else
195+
T real_beta = beta;
196+
#endif
195197
ModuleBase::gemm_op<T, Device>()(ctx,
196198
'C',
197199
'N',
@@ -206,8 +208,10 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
206208
&real_beta,
207209
C_global,
208210
LDC_global);
211+
#ifdef __MPI
209212
Parallel_Common::reduce_dev<T, Device>(C_global, size_C_global, row_world);
210213
}
214+
#endif
211215
}
212216

213217
template class PGemmCN<double, base_device::DEVICE_CPU>;

source/module_esolver/esolver_ks_pw.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,7 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
7373
#if ((defined __CUDA) || (defined __ROCM))
7474
if (this->device == base_device::GpuDevice)
7575
{
76-
hsolver::createGpuBlasHandle();
76+
ModuleBase::createGpuBlasHandle();
7777
hsolver::createGpuSolverHandle();
7878
container::kernels::createGpuBlasHandle();
7979
container::kernels::createGpuSolverHandle();
@@ -101,7 +101,7 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
101101
if (this->device == base_device::GpuDevice)
102102
{
103103
#if defined(__CUDA) || defined(__ROCM)
104-
hsolver::destoryBLAShandle();
104+
ModuleBase::destoryBLAShandle();
105105
hsolver::destroyGpuSolverHandle();
106106
container::kernels::destroyGpuBlasHandle();
107107
container::kernels::destroyGpuSolverHandle();

source/module_hsolver/diago_dav_subspace.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -354,11 +354,11 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
354354
if (this->device == base_device::GpuDevice)
355355
{
356356
syncmem_var_h2d_op()(this->d_precondition, pre.data(), this->dim);
357-
vector_div_vector_op<T, Device>()(this->ctx,
358-
this->dim,
359-
psi_iter + (nbase + m) * this->dim,
360-
psi_iter + (nbase + m) * this->dim,
361-
this->d_precondition);
357+
ModuleBase::vector_div_vector_op<T, Device>()(this->ctx,
358+
this->dim,
359+
psi_iter + (nbase + m) * this->dim,
360+
psi_iter + (nbase + m) * this->dim,
361+
this->d_precondition);
362362
}
363363
else
364364
#endif

source/module_hsolver/diago_david.cpp

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -416,11 +416,11 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
416416
Real* e_temp_gpu = nullptr;
417417
resmem_var_op()(e_temp_gpu, nbase);
418418
syncmem_var_h2d_op()(e_temp_gpu, e_temp_cpu.data(), nbase);
419-
vector_mul_vector_op<T, Device>()(this->ctx,
420-
nbase,
421-
vc_ev_vector + m * nbase,
422-
vc_ev_vector + m * nbase,
423-
e_temp_gpu);
419+
ModuleBase::vector_mul_vector_op<T, Device>()(this->ctx,
420+
nbase,
421+
vc_ev_vector + m * nbase,
422+
vc_ev_vector + m * nbase,
423+
e_temp_gpu);
424424
delmem_var_op()(e_temp_gpu);
425425
#endif
426426
}
@@ -468,11 +468,11 @@ void DiagoDavid<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
468468
if (this->device == base_device::GpuDevice)
469469
{
470470
#if defined(__CUDA) || defined(__ROCM)
471-
vector_div_vector_op<T, Device>()(this->ctx,
472-
dim,
473-
basis + dim*(nbase + m),
474-
basis + dim*(nbase + m),
475-
this->d_precondition);
471+
ModuleBase::vector_div_vector_op<T, Device>()(this->ctx,
472+
dim,
473+
basis + dim * (nbase + m),
474+
basis + dim * (nbase + m),
475+
this->d_precondition);
476476
#endif
477477
}
478478
else

source/module_hsolver/kernels/test/math_dngvd_test.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -144,13 +144,13 @@ TEST_F(TestModuleHsolverMathDngvd, transpose_gpu)
144144
synchronize_memory_op_C2G_Z()(device_transpose, transpose.data(), transpose.size());
145145

146146
// run
147-
hsolver::createGpuBlasHandle();
147+
ModuleBase::createGpuBlasHandle();
148148
ModuleBase::matrixTranspose_op<std::complex<double>, base_device::DEVICE_GPU>()(gpu_ctx,
149149
2,
150150
3,
151151
device_transpose,
152152
device_transpose);
153-
hsolver::destoryBLAShandle();
153+
ModuleBase::destoryBLAShandle();
154154

155155
// copy transpose data from GPU to CPU
156156
std::vector<std::complex<double>> transpose_result = {

source/module_hsolver/kernels/test/perf_math_kernel.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ class PerfModuleHsolverMathKernel : public benchmark::Fixture {
114114
resize_memory_op_double()(test_dvector_a_gpu, dim_vector);
115115
synchronize_memory_op_double()(test_dvector_a_gpu, test_dvector_a, dim_vector);
116116

117-
hsolver::createGpuBlasHandle();
117+
ModuleBase::createGpuBlasHandle();
118118

119119

120120
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
@@ -125,7 +125,7 @@ class PerfModuleHsolverMathKernel : public benchmark::Fixture {
125125
delete[] result_zvector;
126126
delete[] test_dvector_a;
127127
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
128-
hsolver::destoryBLAShandle();
128+
ModuleBase::destoryBLAShandle();
129129
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
130130
}
131131

0 commit comments

Comments
 (0)