Skip to content

Commit 0098171

Browse files
authored
Feature: add para_gemm to do parallel matrix multiply (#5870)
* Feature: add para_gemm to do parallel matrix multi Refator: move math_kernel_op to module_base * fix compile * fix compile * try fix pyabacus * add gatherC for para_gemm * add test
1 parent 51b4c88 commit 0098171

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1656
-686
lines changed

python/pyabacus/CONTRIBUTING.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ Welcome to the `pyabacus` project! This document provides guidelines and instruc
88

99
<!-- toc -->
1010

11-
- [Project structure](#project-structure)
11+
- [Developer Guide](#developer-guide)
12+
- [Introduction](#introduction)
13+
- [Project Structure](#project-structure)
1214
- [Root CMake Configuration](#root-cmake-configuration)
1315
- [Module CMake Configuration](#module-cmake-configuration)
14-
- [Development Process](#development-process)
16+
- [Development Process](#development-process)
17+
- [Conclusion](#conclusion)
1518

1619
<!-- tocstop -->
1720

@@ -187,7 +190,7 @@ list(APPEND _diago
187190
${HSOLVER_PATH}/diag_const_nums.cpp
188191
${HSOLVER_PATH}/diago_iter_assist.cpp
189192
${HSOLVER_PATH}/kernels/dngvd_op.cpp
190-
${HSOLVER_PATH}/kernels/math_kernel_op.cpp
193+
${BASE_PATH}/kernels/math_kernel_op.cpp
191194
${BASE_PATH}/kernels/math_op.cpp
192195
${BASE_PATH}/module_device/device.cpp
193196
${BASE_PATH}/module_device/memory_op.cpp

python/pyabacus/src/ModuleBase/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
list(APPEND pymodule_base
22
${PROJECT_SOURCE_DIR}/src/ModuleBase/py_base_math.cpp
33
${BASE_PATH}/kernels/math_op.cpp
4+
${BASE_PATH}/kernels/math_kernel_op.cpp
45
${BASE_PATH}/module_device/memory_op.cpp
56
${BASE_PATH}/module_device/device.cpp
67
)

python/pyabacus/src/ModuleNAO/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ list(APPEND _naos
1414
${NAO_PATH}/two_center_table.cpp
1515
# dependency
1616
${ABACUS_SOURCE_DIR}/module_base/kernels/math_op.cpp
17+
${ABACUS_SOURCE_DIR}/module_base/kernels/math_kernel_op.cpp
1718
# ${ABACUS_SOURCE_DIR}/module_psi/kernels/psi_memory_op.cpp
1819
${ABACUS_SOURCE_DIR}/module_base/module_device/memory_op.cpp
1920
${ABACUS_SOURCE_DIR}/module_base/module_device/device.cpp

python/pyabacus/src/hsolver/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ list(APPEND _diago
1010

1111

1212
${HSOLVER_PATH}/kernels/dngvd_op.cpp
13-
${HSOLVER_PATH}/kernels/math_kernel_op.cpp
1413
# dependency
14+
${BASE_PATH}/kernels/math_kernel_op.cpp
1515
${BASE_PATH}/kernels/math_op.cpp
1616
${BASE_PATH}/module_device/device.cpp
1717
${BASE_PATH}/module_device/memory_op.cpp

python/pyabacus/src/hsolver/py_hsolver.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <pybind11/numpy.h>
77

88
#include "module_hsolver/diago_dav_subspace.h"
9-
#include "module_hsolver/kernels/math_kernel_op.h"
9+
#include "module_base/kernels/math_kernel_op.h"
1010
#include "module_base/module_device/types.h"
1111

1212
#include "./py_diago_dav_subspace.hpp"

source/CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ list(APPEND device_srcs
3636
module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.cpp
3737
module_basis/module_pw/kernels/pw_op.cpp
3838
module_hsolver/kernels/dngvd_op.cpp
39-
module_hsolver/kernels/math_kernel_op.cpp
4039
module_elecstate/kernels/elecstate_op.cpp
4140

4241
# module_psi/kernels/psi_memory_op.cpp
4342
# module_psi/kernels/device.cpp
4443

4544
module_base/module_device/device.cpp
4645
module_base/module_device/memory_op.cpp
46+
module_base/kernels/math_kernel_op.cpp
4747

4848
module_hamilt_pw/hamilt_pwdft/kernels/force_op.cpp
4949
module_hamilt_pw/hamilt_pwdft/kernels/stress_op.cpp
@@ -64,7 +64,6 @@ if(USE_CUDA)
6464
module_hamilt_pw/hamilt_pwdft/kernels/cuda/onsite_op.cu
6565
module_basis/module_pw/kernels/cuda/pw_op.cu
6666
module_hsolver/kernels/cuda/dngvd_op.cu
67-
module_hsolver/kernels/cuda/math_kernel_op.cu
6867
module_elecstate/kernels/cuda/elecstate_op.cu
6968

7069
# module_psi/kernels/cuda/memory_op.cu
@@ -75,6 +74,7 @@ if(USE_CUDA)
7574
module_hamilt_pw/hamilt_pwdft/kernels/cuda/wf_op.cu
7675
module_hamilt_pw/hamilt_pwdft/kernels/cuda/vnl_op.cu
7776
module_base/kernels/cuda/math_op.cu
77+
module_base/kernels/cuda/math_kernel_op.cu
7878
module_hamilt_general/module_xc/kernels/cuda/xc_functional_op.cu
7979
)
8080
endif()
@@ -89,7 +89,6 @@ if(USE_ROCM)
8989
module_hamilt_pw/hamilt_stodft/kernels/rocm/hpsi_norm_op.hip.cu
9090
module_basis/module_pw/kernels/rocm/pw_op.hip.cu
9191
module_hsolver/kernels/rocm/dngvd_op.hip.cu
92-
module_hsolver/kernels/rocm/math_kernel_op.hip.cu
9392
module_elecstate/kernels/rocm/elecstate_op.hip.cu
9493

9594
# module_psi/kernels/rocm/memory_op.hip.cu
@@ -99,6 +98,7 @@ if(USE_ROCM)
9998
module_hamilt_pw/hamilt_pwdft/kernels/rocm/stress_op.hip.cu
10099
module_hamilt_pw/hamilt_pwdft/kernels/rocm/wf_op.hip.cu
101100
module_hamilt_pw/hamilt_pwdft/kernels/rocm/vnl_op.hip.cu
101+
module_base/kernels/rocm/math_kernel_op.hip.cu
102102
module_base/kernels/rocm/math_op.hip.cu
103103
module_hamilt_general/module_xc/kernels/rocm/xc_functional_op.hip.cu
104104
)

source/Makefile.Objects

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,13 @@ OBJS_BASE=abfs-vector3_order.o\
146146
math_bspline.o\
147147
math_chebyshev.o\
148148
math_op.o\
149+
math_kernel_op.o\
149150
mathzone_add1.o\
150151
matrix.o\
151152
matrix3.o\
152153
memory.o\
153154
mymath.o\
155+
para_gemm.o\
154156
realarray.o\
155157
sph_bessel_recursive-d1.o\
156158
sph_bessel_recursive-d2.o\
@@ -336,7 +338,6 @@ OBJS_HSOLVER=diago_cg.o\
336338
hsolver_lcaopw.o\
337339
hsolver_pw_sdft.o\
338340
diago_iter_assist.o\
339-
math_kernel_op.o\
340341
dngvd_op.o\
341342
diag_const_nums.o\
342343
diag_hs_para.o\

source/module_base/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ add_library(
3737
mymath.cpp
3838
opt_CG.cpp
3939
opt_DCsrch.cpp
40+
para_gemm.cpp
4041
realarray.cpp
4142
sph_bessel_recursive-d1.cpp
4243
sph_bessel_recursive-d2.cpp

source/module_base/blas_connector.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <base/macros/macros.h>
1111
#include <cuda_runtime.h>
1212
#include "cublas_v2.h"
13-
#include "module_hsolver/kernels/math_kernel_op.h"
13+
#include "module_base/kernels/math_kernel_op.h"
1414
#include "module_base/module_device/memory_op.h"
1515

1616

@@ -668,7 +668,7 @@ void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vec
668668
}
669669
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
670670
#ifdef __CUDA
671-
hsolver::vector_mul_vector_op<T, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, vector2);
671+
ModuleBase::vector_mul_vector_op<T, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, vector2);
672672
#endif
673673
}
674674
}
@@ -688,7 +688,7 @@ void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vec
688688
}
689689
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
690690
#ifdef __CUDA
691-
hsolver::vector_div_vector_op<T, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, vector2);
691+
ModuleBase::vector_div_vector_op<T, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, vector2);
692692
#endif
693693
}
694694
}
@@ -706,7 +706,7 @@ void vector_add_vector(const int& dim, float *result, const float *vector1, cons
706706
}
707707
else if (device_type == base_device::GpuDevice){
708708
#ifdef __CUDA
709-
hsolver::constantvector_addORsub_constantVector_op<float, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
709+
ModuleBase::constantvector_addORsub_constantVector_op<float, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
710710
#endif
711711
}
712712
}
@@ -724,7 +724,7 @@ void vector_add_vector(const int& dim, double *result, const double *vector1, co
724724
}
725725
else if (device_type == base_device::GpuDevice){
726726
#ifdef __CUDA
727-
hsolver::constantvector_addORsub_constantVector_op<double, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
727+
ModuleBase::constantvector_addORsub_constantVector_op<double, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
728728
#endif
729729
}
730730
}
@@ -742,7 +742,7 @@ void vector_add_vector(const int& dim, std::complex<float> *result, const std::c
742742
}
743743
else if (device_type == base_device::GpuDevice){
744744
#ifdef __CUDA
745-
hsolver::constantvector_addORsub_constantVector_op<std::complex<float>, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
745+
ModuleBase::constantvector_addORsub_constantVector_op<std::complex<float>, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
746746
#endif
747747
}
748748
}
@@ -760,7 +760,7 @@ void vector_add_vector(const int& dim, std::complex<double> *result, const std::
760760
}
761761
else if (device_type == base_device::GpuDevice){
762762
#ifdef __CUDA
763-
hsolver::constantvector_addORsub_constantVector_op<std::complex<double>, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
763+
ModuleBase::constantvector_addORsub_constantVector_op<std::complex<double>, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2);
764764
#endif
765765
}
766766
}

source/module_hsolver/kernels/cuda/math_kernel_op.cu renamed to source/module_base/kernels/cuda/math_kernel_op.cu

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#include "module_base/module_device/memory_op.h"
2-
#include "module_hsolver/kernels/math_kernel_op.h"
2+
#include "module_base/kernels/math_kernel_op.h"
33
#include "module_psi/psi.h"
44
#include "module_base/tool_quit.h"
55

@@ -9,7 +9,7 @@
99
#include <thrust/execution_policy.h>
1010
#include <thrust/inner_product.h>
1111

12-
namespace hsolver
12+
namespace ModuleBase
1313
{
1414
const int warp_size = 32;
1515
// const unsigned int full_mask = 0xffffffff;
@@ -24,7 +24,7 @@ template <>
2424
struct GetTypeReal<thrust::complex<double>> {
2525
using type = double; /**< The return type specialization for std::complex<double>. */
2626
};
27-
namespace hsolver {
27+
namespace ModuleBase {
2828
template <typename T>
2929
struct GetTypeThrust {
3030
using type = T;
@@ -817,6 +817,27 @@ void scal_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEV
817817
cublasErrcheck(cublasZscal(cublas_handle, N, (double2*)alpha, (double2*)X, incx));
818818
}
819819

820+
template <>
821+
void gemm_op<float, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
822+
const char& transa,
823+
const char& transb,
824+
const int& m,
825+
const int& n,
826+
const int& k,
827+
const float* alpha,
828+
const float* a,
829+
const int& lda,
830+
const float* b,
831+
const int& ldb,
832+
const float* beta,
833+
float* c,
834+
const int& ldc)
835+
{
836+
cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op");
837+
cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op");
838+
cublasErrcheck(cublasSgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc));
839+
}
840+
820841
template <>
821842
void gemm_op<double, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* d,
822843
const char& transa,
@@ -1060,4 +1081,4 @@ template struct vector_div_vector_op<double, base_device::DEVICE_GPU>;
10601081
template struct matrixSetToAnother<double, base_device::DEVICE_GPU>;
10611082
template struct constantvector_addORsub_constantVector_op<double, base_device::DEVICE_GPU>;
10621083
#endif
1063-
} // namespace hsolver
1084+
} // namespace ModuleBase

0 commit comments

Comments
 (0)