diff --git a/python/pyabacus/CONTRIBUTING.md b/python/pyabacus/CONTRIBUTING.md index fbd23ad9ff..b5d7728eae 100644 --- a/python/pyabacus/CONTRIBUTING.md +++ b/python/pyabacus/CONTRIBUTING.md @@ -8,10 +8,13 @@ Welcome to the `pyabacus` project! This document provides guidelines and instruc -- [Project structure](#project-structure) +- [Developer Guide](#developer-guide) + - [Introduction](#introduction) + - [Project Structure](#project-structure) - [Root CMake Configuration](#root-cmake-configuration) - [Module CMake Configuration](#module-cmake-configuration) -- [Development Process](#development-process) + - [Development Process](#development-process) + - [Conclusion](#conclusion) @@ -187,7 +190,7 @@ list(APPEND _diago ${HSOLVER_PATH}/diag_const_nums.cpp ${HSOLVER_PATH}/diago_iter_assist.cpp ${HSOLVER_PATH}/kernels/dngvd_op.cpp - ${HSOLVER_PATH}/kernels/math_kernel_op.cpp + ${BASE_PATH}/kernels/math_kernel_op.cpp ${BASE_PATH}/kernels/math_op.cpp ${BASE_PATH}/module_device/device.cpp ${BASE_PATH}/module_device/memory_op.cpp diff --git a/python/pyabacus/src/ModuleBase/CMakeLists.txt b/python/pyabacus/src/ModuleBase/CMakeLists.txt index 7ce5fb5e3b..1c2d9a728b 100644 --- a/python/pyabacus/src/ModuleBase/CMakeLists.txt +++ b/python/pyabacus/src/ModuleBase/CMakeLists.txt @@ -1,6 +1,7 @@ list(APPEND pymodule_base ${PROJECT_SOURCE_DIR}/src/ModuleBase/py_base_math.cpp ${BASE_PATH}/kernels/math_op.cpp + ${BASE_PATH}/kernels/math_kernel_op.cpp ${BASE_PATH}/module_device/memory_op.cpp ${BASE_PATH}/module_device/device.cpp ) diff --git a/python/pyabacus/src/ModuleNAO/CMakeLists.txt b/python/pyabacus/src/ModuleNAO/CMakeLists.txt index c5eb016903..5e86604adc 100644 --- a/python/pyabacus/src/ModuleNAO/CMakeLists.txt +++ b/python/pyabacus/src/ModuleNAO/CMakeLists.txt @@ -14,6 +14,7 @@ list(APPEND _naos ${NAO_PATH}/two_center_table.cpp # dependency ${ABACUS_SOURCE_DIR}/module_base/kernels/math_op.cpp + ${ABACUS_SOURCE_DIR}/module_base/kernels/math_kernel_op.cpp # ${ABACUS_SOURCE_DIR}/module_psi/kernels/psi_memory_op.cpp ${ABACUS_SOURCE_DIR}/module_base/module_device/memory_op.cpp ${ABACUS_SOURCE_DIR}/module_base/module_device/device.cpp diff --git a/python/pyabacus/src/hsolver/CMakeLists.txt b/python/pyabacus/src/hsolver/CMakeLists.txt index f0f04f97a7..4bd0153b48 100644 --- a/python/pyabacus/src/hsolver/CMakeLists.txt +++ b/python/pyabacus/src/hsolver/CMakeLists.txt @@ -10,8 +10,8 @@ list(APPEND _diago ${HSOLVER_PATH}/kernels/dngvd_op.cpp - ${HSOLVER_PATH}/kernels/math_kernel_op.cpp # dependency + ${BASE_PATH}/kernels/math_kernel_op.cpp ${BASE_PATH}/kernels/math_op.cpp ${BASE_PATH}/module_device/device.cpp ${BASE_PATH}/module_device/memory_op.cpp diff --git a/python/pyabacus/src/hsolver/py_hsolver.cpp b/python/pyabacus/src/hsolver/py_hsolver.cpp index e791fe9f09..3c4d1c66c4 100644 --- a/python/pyabacus/src/hsolver/py_hsolver.cpp +++ b/python/pyabacus/src/hsolver/py_hsolver.cpp @@ -6,7 +6,7 @@ #include #include "module_hsolver/diago_dav_subspace.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_base/module_device/types.h" #include "./py_diago_dav_subspace.hpp" diff --git a/source/CMakeLists.txt b/source/CMakeLists.txt index 1f4d4a8370..769138b096 100644 --- a/source/CMakeLists.txt +++ b/source/CMakeLists.txt @@ -36,7 +36,6 @@ list(APPEND device_srcs module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.cpp module_basis/module_pw/kernels/pw_op.cpp module_hsolver/kernels/dngvd_op.cpp - module_hsolver/kernels/math_kernel_op.cpp 
module_elecstate/kernels/elecstate_op.cpp # module_psi/kernels/psi_memory_op.cpp @@ -44,6 +43,7 @@ list(APPEND device_srcs module_base/module_device/device.cpp module_base/module_device/memory_op.cpp + module_base/kernels/math_kernel_op.cpp module_hamilt_pw/hamilt_pwdft/kernels/force_op.cpp module_hamilt_pw/hamilt_pwdft/kernels/stress_op.cpp @@ -64,7 +64,6 @@ if(USE_CUDA) module_hamilt_pw/hamilt_pwdft/kernels/cuda/onsite_op.cu module_basis/module_pw/kernels/cuda/pw_op.cu module_hsolver/kernels/cuda/dngvd_op.cu - module_hsolver/kernels/cuda/math_kernel_op.cu module_elecstate/kernels/cuda/elecstate_op.cu # module_psi/kernels/cuda/memory_op.cu @@ -75,6 +74,7 @@ if(USE_CUDA) module_hamilt_pw/hamilt_pwdft/kernels/cuda/wf_op.cu module_hamilt_pw/hamilt_pwdft/kernels/cuda/vnl_op.cu module_base/kernels/cuda/math_op.cu + module_base/kernels/cuda/math_kernel_op.cu module_hamilt_general/module_xc/kernels/cuda/xc_functional_op.cu ) endif() @@ -89,7 +89,6 @@ if(USE_ROCM) module_hamilt_pw/hamilt_stodft/kernels/rocm/hpsi_norm_op.hip.cu module_basis/module_pw/kernels/rocm/pw_op.hip.cu module_hsolver/kernels/rocm/dngvd_op.hip.cu - module_hsolver/kernels/rocm/math_kernel_op.hip.cu module_elecstate/kernels/rocm/elecstate_op.hip.cu # module_psi/kernels/rocm/memory_op.hip.cu @@ -99,6 +98,7 @@ if(USE_ROCM) module_hamilt_pw/hamilt_pwdft/kernels/rocm/stress_op.hip.cu module_hamilt_pw/hamilt_pwdft/kernels/rocm/wf_op.hip.cu module_hamilt_pw/hamilt_pwdft/kernels/rocm/vnl_op.hip.cu + module_base/kernels/rocm/math_kernel_op.hip.cu module_base/kernels/rocm/math_op.hip.cu module_hamilt_general/module_xc/kernels/rocm/xc_functional_op.hip.cu ) diff --git a/source/Makefile.Objects b/source/Makefile.Objects index 5d01dd1839..ad13d75976 100644 --- a/source/Makefile.Objects +++ b/source/Makefile.Objects @@ -146,11 +146,13 @@ OBJS_BASE=abfs-vector3_order.o\ math_bspline.o\ math_chebyshev.o\ math_op.o\ + math_kernel_op.o\ mathzone_add1.o\ matrix.o\ matrix3.o\ memory.o\ mymath.o\ + para_gemm.o\ realarray.o\ sph_bessel_recursive-d1.o\ sph_bessel_recursive-d2.o\ @@ -336,7 +338,6 @@ OBJS_HSOLVER=diago_cg.o\ hsolver_lcaopw.o\ hsolver_pw_sdft.o\ diago_iter_assist.o\ - math_kernel_op.o\ dngvd_op.o\ diag_const_nums.o\ diag_hs_para.o\ diff --git a/source/module_base/CMakeLists.txt b/source/module_base/CMakeLists.txt index 38c466a2c1..ecbdedcf6a 100644 --- a/source/module_base/CMakeLists.txt +++ b/source/module_base/CMakeLists.txt @@ -37,6 +37,7 @@ add_library( mymath.cpp opt_CG.cpp opt_DCsrch.cpp + para_gemm.cpp realarray.cpp sph_bessel_recursive-d1.cpp sph_bessel_recursive-d2.cpp diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 14fb76e2ed..b422969ac5 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -10,7 +10,7 @@ #include #include #include "cublas_v2.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_base/module_device/memory_op.h" @@ -668,7 +668,7 @@ void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vec } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - hsolver::vector_mul_vector_op()(gpu_ctx, dim, result, vector1, vector2); + ModuleBase::vector_mul_vector_op()(gpu_ctx, dim, result, vector1, vector2); #endif } } @@ -688,7 +688,7 @@ void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vec } else if (device_type == base_device::AbacusDevice_t::GpuDevice){ #ifdef __CUDA - 
hsolver::vector_div_vector_op()(gpu_ctx, dim, result, vector1, vector2); + ModuleBase::vector_div_vector_op()(gpu_ctx, dim, result, vector1, vector2); #endif } } @@ -706,7 +706,7 @@ void vector_add_vector(const int& dim, float *result, const float *vector1, cons } else if (device_type == base_device::GpuDevice){ #ifdef __CUDA - hsolver::constantvector_addORsub_constantVector_op()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); + ModuleBase::constantvector_addORsub_constantVector_op()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); #endif } } @@ -724,7 +724,7 @@ void vector_add_vector(const int& dim, double *result, const double *vector1, co } else if (device_type == base_device::GpuDevice){ #ifdef __CUDA - hsolver::constantvector_addORsub_constantVector_op()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); + ModuleBase::constantvector_addORsub_constantVector_op()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); #endif } } @@ -742,7 +742,7 @@ void vector_add_vector(const int& dim, std::complex *result, const std::c } else if (device_type == base_device::GpuDevice){ #ifdef __CUDA - hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); + ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); #endif } } @@ -760,7 +760,7 @@ void vector_add_vector(const int& dim, std::complex *result, const std:: } else if (device_type == base_device::GpuDevice){ #ifdef __CUDA - hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); + ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>()(gpu_ctx, dim, result, vector1, constant1, vector2, constant2); #endif } } \ No newline at end of file diff --git a/source/module_hsolver/kernels/cuda/math_kernel_op.cu b/source/module_base/kernels/cuda/math_kernel_op.cu similarity index 97% rename from source/module_hsolver/kernels/cuda/math_kernel_op.cu rename to source/module_base/kernels/cuda/math_kernel_op.cu index cd3ac41812..d48862ef33 100644 --- a/source/module_hsolver/kernels/cuda/math_kernel_op.cu +++ b/source/module_base/kernels/cuda/math_kernel_op.cu @@ -1,5 +1,5 @@ #include "module_base/module_device/memory_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" #include "module_base/tool_quit.h" @@ -9,7 +9,7 @@ #include #include -namespace hsolver +namespace ModuleBase { const int warp_size = 32; // const unsigned int full_mask = 0xffffffff; @@ -24,7 +24,7 @@ template <> struct GetTypeReal> { using type = double; /**< The return type specialization for std::complex. 
*/ }; -namespace hsolver { +namespace ModuleBase { template struct GetTypeThrust { using type = T; @@ -817,6 +817,27 @@ void scal_op::operator()(const base_device::DEV cublasErrcheck(cublasZscal(cublas_handle, N, (double2*)alpha, (double2*)X, incx)); } +template <> +void gemm_op::operator()(const base_device::DEVICE_GPU* d, + const char& transa, + const char& transb, + const int& m, + const int& n, + const int& k, + const float* alpha, + const float* a, + const int& lda, + const float* b, + const int& ldb, + const float* beta, + float* c, + const int& ldc) +{ + cublasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op"); + cublasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op"); + cublasErrcheck(cublasSgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)); +} + template <> void gemm_op::operator()(const base_device::DEVICE_GPU* d, const char& transa, @@ -1060,4 +1081,4 @@ template struct vector_div_vector_op; template struct matrixSetToAnother; template struct constantvector_addORsub_constantVector_op; #endif -} // namespace hsolver +} // namespace ModuleBase diff --git a/source/module_hsolver/kernels/math_kernel_op.cpp b/source/module_base/kernels/math_kernel_op.cpp similarity index 99% rename from source/module_hsolver/kernels/math_kernel_op.cpp rename to source/module_base/kernels/math_kernel_op.cpp index db2a12e9db..59a3c2ace8 100644 --- a/source/module_hsolver/kernels/math_kernel_op.cpp +++ b/source/module_base/kernels/math_kernel_op.cpp @@ -1,9 +1,9 @@ -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include #include -namespace hsolver +namespace ModuleBase { template diff --git a/source/module_hsolver/kernels/math_kernel_op.h b/source/module_base/kernels/math_kernel_op.h similarity index 99% rename from source/module_hsolver/kernels/math_kernel_op.h rename to source/module_base/kernels/math_kernel_op.h index 0daf0e5718..b525ce8467 100644 --- a/source/module_hsolver/kernels/math_kernel_op.h +++ b/source/module_base/kernels/math_kernel_op.h @@ -17,7 +17,7 @@ #include "cublas_v2.h" #endif //__CUDA || __UT_USE_CUDA -namespace hsolver { +namespace ModuleBase { inline std::complex set_real_tocomplex(const std::complex &x) { return {x.real(), 0.0}; diff --git a/source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu b/source/module_base/kernels/rocm/math_kernel_op.hip.cu similarity index 96% rename from source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu rename to source/module_base/kernels/rocm/math_kernel_op.hip.cu index 1993ae4c64..5ee0648e11 100644 --- a/source/module_hsolver/kernels/rocm/math_kernel_op.hip.cu +++ b/source/module_base/kernels/rocm/math_kernel_op.hip.cu @@ -1,5 +1,5 @@ #include "module_base/module_device/memory_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" #include "module_base/tool_quit.h" @@ -20,7 +20,7 @@ struct GetTypeReal> { using type = double; /**< The return type specialization for std::complex. 
*/ }; -namespace hsolver { +namespace ModuleBase { template struct GetTypeThrust { @@ -735,6 +735,27 @@ void scal_op::operator()(const base_device::DEV hipblasErrcheck(hipblasZscal(cublas_handle, N, (hipblasDoubleComplex*)alpha, (hipblasDoubleComplex*)X, incx)); } +template <> +void gemm_op::operator()(const base_device::DEVICE_GPU* d, + const char& transa, + const char& transb, + const int& m, + const int& n, + const int& k, + const float* alpha, + const float* a, + const int& lda, + const float* b, + const int& ldb, + const float* beta, + float* c, + const int& ldc) +{ + hipblasOperation_t cutransA = judge_trans_op(false, transa, "gemm_op"); + hipblasOperation_t cutransB = judge_trans_op(false, transb, "gemm_op"); + hipblasErrcheck(hipblasSgemm(cublas_handle, cutransA, cutransB, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc)); +} + template <> void gemm_op::operator()(const base_device::DEVICE_GPU* d, const char& transa, @@ -968,4 +989,4 @@ template struct vector_div_vector_op; template struct matrixSetToAnother; template struct constantvector_addORsub_constantVector_op; #endif -} // namespace hsolver +} // namespace ModuleBase diff --git a/source/module_base/kernels/test/CMakeLists.txt b/source/module_base/kernels/test/CMakeLists.txt index 960de3b613..1453545d14 100644 --- a/source/module_base/kernels/test/CMakeLists.txt +++ b/source/module_base/kernels/test/CMakeLists.txt @@ -3,6 +3,5 @@ remove_definitions(-D__MPI) AddTest( TARGET Base_Kernels_UTs LIBS parameter ${math_libs} base device - SOURCES math_op_test.cpp + SOURCES math_op_test.cpp math_kernel_test.cpp ) - diff --git a/source/module_hsolver/kernels/test/math_kernel_test.cpp b/source/module_base/kernels/test/math_kernel_test.cpp similarity index 93% rename from source/module_hsolver/kernels/test/math_kernel_test.cpp rename to source/module_base/kernels/test/math_kernel_test.cpp index 0781d54787..caf320ef81 100644 --- a/source/module_hsolver/kernels/test/math_kernel_test.cpp +++ b/source/module_base/kernels/test/math_kernel_test.cpp @@ -1,7 +1,7 @@ #include "module_base/blas_connector.h" #include "module_base/constants.h" #include "module_base/module_device/memory_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include #include @@ -51,8 +51,8 @@ class TestModuleHsolverMathKernel : public ::testing::Test { } - using zdot_real_cpu_op = hsolver::dot_real_op, base_device::DEVICE_CPU>; - using zdot_real_gpu_op = hsolver::dot_real_op, base_device::DEVICE_GPU>; + using zdot_real_cpu_op = ModuleBase::dot_real_op, base_device::DEVICE_CPU>; + using zdot_real_gpu_op = ModuleBase::dot_real_op, base_device::DEVICE_GPU>; using resize_memory_op = base_device::memory::resize_memory_op, base_device::DEVICE_GPU>; using delete_memory_op = base_device::memory::delete_memory_op, base_device::DEVICE_GPU>; @@ -72,23 +72,23 @@ class TestModuleHsolverMathKernel : public ::testing::Test // haozhihan add // cpu operator - using vector_div_constant_op_cpu = hsolver::vector_div_constant_op, base_device::DEVICE_CPU>; - using vector_mul_vector_op_cpu = hsolver::vector_mul_vector_op, base_device::DEVICE_CPU>; - using vector_div_vector_op_cpu = hsolver::vector_div_vector_op, base_device::DEVICE_CPU>; + using vector_div_constant_op_cpu = ModuleBase::vector_div_constant_op, base_device::DEVICE_CPU>; + using vector_mul_vector_op_cpu = ModuleBase::vector_mul_vector_op, base_device::DEVICE_CPU>; + using vector_div_vector_op_cpu = ModuleBase::vector_div_vector_op, base_device::DEVICE_CPU>; using 
constantvector_addORsub_constantVector_op_cpu - = hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_CPU>; - using axpy_op_cpu = hsolver::axpy_op, base_device::DEVICE_CPU>; - using scal_op_cpu = hsolver::scal_op; - using gemv_op_cpu = hsolver::gemv_op, base_device::DEVICE_CPU>; + = ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_CPU>; + using axpy_op_cpu = ModuleBase::axpy_op, base_device::DEVICE_CPU>; + using scal_op_cpu = ModuleBase::scal_op; + using gemv_op_cpu = ModuleBase::gemv_op, base_device::DEVICE_CPU>; // gpu operator - using vector_div_constant_op_gpu = hsolver::vector_div_constant_op, base_device::DEVICE_GPU>; - using vector_mul_vector_op_gpu = hsolver::vector_mul_vector_op, base_device::DEVICE_GPU>; - using vector_div_vector_op_gpu = hsolver::vector_div_vector_op, base_device::DEVICE_GPU>; + using vector_div_constant_op_gpu = ModuleBase::vector_div_constant_op, base_device::DEVICE_GPU>; + using vector_mul_vector_op_gpu = ModuleBase::vector_mul_vector_op, base_device::DEVICE_GPU>; + using vector_div_vector_op_gpu = ModuleBase::vector_div_vector_op, base_device::DEVICE_GPU>; using constantvector_addORsub_constantVector_op_gpu - = hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>; - using axpy_op_gpu = hsolver::axpy_op, base_device::DEVICE_GPU>; - using scal_op_gpu = hsolver::scal_op; - using gemv_op_gpu = hsolver::gemv_op, base_device::DEVICE_GPU>; + = ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>; + using axpy_op_gpu = ModuleBase::axpy_op, base_device::DEVICE_GPU>; + using scal_op_gpu = ModuleBase::scal_op; + using gemv_op_gpu = ModuleBase::gemv_op, base_device::DEVICE_GPU>; // haozhihan add std::vector> L = {{-0.65412617, -0.74208893}, @@ -375,9 +375,9 @@ TEST_F(TestModuleHsolverMathKernel, zdot_real_op_gpu) resize_memory_op()(psi_R_dev, psi_R.size()); synchronize_memory_op()(psi_L_dev, psi_L.data(), psi_L.size()); synchronize_memory_op()(psi_R_dev, psi_R.data(), psi_R.size()); - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); double result = zdot_real_gpu_op()(gpu_ctx, dim, psi_L_dev, psi_R_dev, false); - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); EXPECT_LT(fabs(result - expected_result), 1e-12); delete_memory_op()(psi_L_dev); delete_memory_op()(psi_R_dev); @@ -537,9 +537,9 @@ TEST_F(TestModuleHsolverMathKernel, axpy_op_gpu) synchronize_memory_op()(Y_axpy_dev, Y_axpy.data(), Y_axpy.size()); // run - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); axpy_op_gpu()(gpu_ctx, dim, &alpha_axpy, X_axpy_dev, 1, Y_axpy_dev, 1); - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); // syn the output data in GPU to CPU synchronize_memory_op_gpu()(Y_axpy.data(), Y_axpy_dev, Y_axpy.size()); @@ -566,9 +566,9 @@ TEST_F(TestModuleHsolverMathKernel, scal_op_gpu) synchronize_memory_op()(X_scal_dev, X_scal.data(), X_scal.size()); // run - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); scal_op_gpu()(gpu_ctx, dim, &alpha_scal, X_scal_dev, 1); - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); // syn the output data in GPU to CPU synchronize_memory_op_gpu()(X_scal.data(), X_scal_dev, X_scal.size()); @@ -599,9 +599,9 @@ TEST_F(TestModuleHsolverMathKernel, gemv_op_gpu) synchronize_memory_op()(Y_gemv_dev, Y_gemv.data(), Y_gemv.size()); // run - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); gemv_op_gpu()(gpu_ctx, 'C', 2, 3, &ModuleBase::ONE, A_gemv_dev, 2, 
X_gemv_dev, 1, &ModuleBase::ONE, Y_gemv_dev, 1); - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); // syn the output data in GPU to CPU synchronize_memory_op_gpu()(Y_gemv.data(), Y_gemv_dev, Y_gemv.size()); @@ -668,7 +668,7 @@ TEST_F(TestModuleHsolverMathKernel, matrixSetToAnother_op_gpu) B.size()); // run - hsolver::matrixSetToAnother, base_device::DEVICE_GPU>()(gpu_ctx, + ModuleBase::matrixSetToAnother, base_device::DEVICE_GPU>()(gpu_ctx, n, device_A, LDA, @@ -683,7 +683,7 @@ TEST_F(TestModuleHsolverMathKernel, matrixSetToAnother_op_gpu) B_gpu2cpu.size()); std::vector> B_cpu(8); - hsolver::matrixSetToAnother, base_device::DEVICE_CPU>()(cpu_ctx, + ModuleBase::matrixSetToAnother, base_device::DEVICE_CPU>()(cpu_ctx, n, A.data(), LDA, diff --git a/source/module_base/para_gemm.cpp b/source/module_base/para_gemm.cpp new file mode 100644 index 0000000000..0908457108 --- /dev/null +++ b/source/module_base/para_gemm.cpp @@ -0,0 +1,239 @@ +#include "para_gemm.h" + +#include "kernels/math_kernel_op.h" +#include "parallel_device.h" +namespace ModuleBase +{ +template +PGemmCN::PGemmCN() +{ +} +template +PGemmCN::~PGemmCN() +{ +} + +template +void PGemmCN::set_dimension( +#ifdef __MPI + MPI_Comm comm_col, + MPI_Comm comm_row, +#endif + const int ncolA_in, + const int LDA_in, + const int ncolB_in, + const int LDB_in, + const int nrow_in, + const int LDC_in, + const bool gatherC_in) +{ +#ifdef __MPI + MPI_Comm_rank(comm_col, &col_rank); + MPI_Comm_size(comm_col, &col_nproc); + if (comm_row != MPI_COMM_NULL) + { + MPI_Comm_rank(comm_row, &row_rank); + MPI_Comm_size(comm_row, &row_nproc); + } + col_world = comm_col; + row_world = comm_row; +#endif + this->LDA = LDA_in; + this->LDB = LDB_in; + this->LDC = LDC_in; + this->ncolA = ncolA_in; + this->ncolB = ncolB_in; + this->nrow = nrow_in; +#ifdef __MPI + this->gatherC = gatherC_in; + requests.resize(col_nproc); + colA_loc.resize(col_nproc); + MPI_Allgather(&ncolA, 1, MPI_INT, colA_loc.data(), 1, MPI_INT, col_world); + for (int ip = 0; ip < col_nproc; ip++) + { + max_colA = std::max(max_colA, colA_loc[ip]); + } + + if (this->gatherC) + { + colB_loc.resize(col_nproc); + recv_counts.resize(col_nproc); + displs.resize(col_nproc); + MPI_Allgather(&ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, col_world); + for (int ip = 0; ip < col_nproc; ip++) + { + recv_counts[ip] = LDC * colB_loc[ip]; + } + displs[0] = 0; + for (int ip = 1; ip < col_nproc; ip++) + { + displs[ip] = displs[ip - 1] + recv_counts[ip - 1]; + } + size_C_global = displs[col_nproc - 1] + recv_counts[col_nproc - 1]; + } + size_C_local = ncolB * LDC; +#endif +} + +template +void PGemmCN::multiply(const T alpha, const T* A, const T* B, const T beta, T* C) +{ + const Device* ctx = {}; +#ifdef __MPI + if (col_nproc > 1) + { + std::vector A_tmp(max_colA * LDA); + for (int ip = 0; ip < col_nproc; ip++) + { + if (col_rank != ip) + { + int size = ncolA * LDA; + Parallel_Common::isend_dev(A, size, ip, 0, col_world, &requests[ip], A_tmp.data()); + } + } + + T* C_local = C; + std::vector C_tmp; + if (this->gatherC) + { + C_tmp.resize(size_C_local); + if (std::is_same::value) + { + C_local = nullptr; + resmem_dev_op()(C_local, size_C_local); + } + else + { + C_local = C_tmp.data(); + } + syncmem_dev_op()(C_local, C + displs[col_rank], size_C_local); + } + + T* Atmp_device = nullptr; + if (std::is_same::value) + { + resmem_dev_op()(Atmp_device, max_colA * LDA); + } + else + { + Atmp_device = A_tmp.data(); + } + + int shift = 0; + T real_beta = row_rank == 0 ? 
beta : 0; + for (int ip = 0; ip < col_nproc; ip++) + { + T* C_start = C_local + shift; + if (col_rank == ip) + { + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + ncolA, + ncolB, + nrow, + &alpha, + A, + LDA, + B, + LDB, + &real_beta, + C_start, + LDC); + shift += ncolA; + } + else + { + int m = colA_loc[ip]; + int size = m * LDA; + MPI_Status status; + Parallel_Common::recv_dev(Atmp_device, size, ip, 0, col_world, &status, A_tmp.data()); + MPI_Wait(&requests[ip], &status); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + m, + ncolB, + nrow, + &alpha, + Atmp_device, + LDA, + B, + LDB, + &real_beta, + C_start, + LDC); + shift += m; + } + } + + if (this->gatherC) + { + T* Cglobal_cpu = nullptr; + T* Clocal_cpu = C_tmp.data();; + if (std::is_same::value) + { + delmem_dev_op()(Atmp_device); + + syncmem_d2h_op()(Clocal_cpu, C_local, size_C_local); + delmem_dev_op()(C_local); + + resmem_dev_op()(Cglobal_cpu, size_C_global); + } + else + { + Cglobal_cpu = C; + } + if (this->row_nproc > 1) + { + Parallel_Common::reduce_data(Clocal_cpu, size_C_local, row_world); + } + Parallel_Common::gatherv_data(Clocal_cpu, + size_C_local, + Cglobal_cpu, + recv_counts.data(), + displs.data(), + col_world); + + if (std::is_same::value) + { + syncmem_h2d_op()(C, Cglobal_cpu, size_C_global); + delmem_dev_op()(Cglobal_cpu); + } + } + else + { + if (this->row_nproc > 1) + { + Parallel_Common::reduce_dev(C, size_C_local, row_world); + } + } + } + else + { + T real_beta = row_rank == 0 ? beta : 0; +#else + T real_beta = beta; +#endif + ModuleBase::gemm_op()(ctx, 'C', 'N', ncolA, ncolB, nrow, &alpha, A, LDA, B, LDB, &real_beta, C, LDC); +#ifdef __MPI + if (this->row_nproc > 1) + { + Parallel_Common::reduce_dev(C, size_C_local, row_world); + } + } +#endif +} + +template class PGemmCN; +template class PGemmCN; +template class PGemmCN, base_device::DEVICE_CPU>; +template class PGemmCN, base_device::DEVICE_CPU>; +#if ((defined __CUDA) || (defined __ROCM)) +template class PGemmCN; +template class PGemmCN; +template class PGemmCN, base_device::DEVICE_GPU>; +template class PGemmCN, base_device::DEVICE_GPU>; +#endif + +} // namespace ModuleBase \ No newline at end of file diff --git a/source/module_base/para_gemm.h b/source/module_base/para_gemm.h new file mode 100644 index 0000000000..69ffd6d146 --- /dev/null +++ b/source/module_base/para_gemm.h @@ -0,0 +1,93 @@ +#ifndef PARA_GEMM_H +#define PARA_GEMM_H +#include "module_base/module_device/device.h" +#include "module_base/module_device/memory_op.h" + +#include +#ifdef __MPI +#include "mpi.h" +#endif + +namespace ModuleBase +{ +/** + * @brief this class is used to perform parallel matrix multiplication + * C = alpha * A^H * B + beta * C + * Here, A and B are local matrices in each proc, + * C can be C_local or C_global, depending on the value of gatherC + * C_local is a local matrix in each proc + * C_global is a global matrix gathered from all procs and all procs have their own C_global matrix with the same + * C_global and C_local have the same LDC, but different column numbers + * values. + */ +template +class PGemmCN +{ + public: + PGemmCN(); + ~PGemmCN(); + + /** + * @brief set the dimension of A, B, and C + * + * @param ncolA number of columns of A, which is a local matrix in each proc + * @param LDA leading dimension of A in each proc + * @param ncolB number of columns of B, which is a local matrix in each proc + * @param LDB leading dimension of B in each proc + * @param nrow number of rows of A or B + * @param LDC leading dimension of C. 
C can be C_local or C_global + * @param gatherC whether gather C_local to C_global + */ + void set_dimension( +#ifdef __MPI + MPI_Comm comm_col, + MPI_Comm comm_row, +#endif + const int ncolA, + const int LDA, + const int ncolB, + const int LDB, + const int nrow, + const int LDC, + const bool gatherC = true); + + /** + * @brief calculate C = alpha * A^H * B + beta * C + * + */ + void multiply(const T alpha, const T* A, const T* B, const T beta, T* C); +#ifdef __MPI + MPI_Comm col_world = MPI_COMM_NULL; ///< column communicator world + MPI_Comm row_world = MPI_COMM_NULL; ///< row communicator world + + int col_rank = 0; ///< rank in col_world + int col_nproc = 1; ///< number of procs in col_world + int row_rank = 0; ///< rank in row_world + int row_nproc = 1; ///< number of procs in row_world + + std::vector colA_loc; ///< [col_nproc] number of columns of A matrix in each proc + int max_colA = 0; ///< maximum number of columns of A matrix in all procs + std::vector colB_loc; ///<[col_nproc] number of columns of B matrix in each proc + + std::vector requests; ///< MPI request + std::vector recv_counts; ///< receive counts for gathering C_local to C_global + std::vector displs; ///< displacements for gathering C_local to C_global + int size_C_local = 0; ///< size of C_local, which is a local matrix in each proc + int size_C_global = 0; ///< size of C_global, which is the global C matrix gathered from all procs + bool gatherC = true; ///< whether gather C_local to C_global +#endif + int ncolA = 0; ///< number of columns of A, which is a local matrix in each proc + int ncolB = 0; ///< number of columns of B, which is a local matrix in each proc + int nrow = 0; ///< number of rows of A or B + int LDA = 0; ///< leading dimension of A in each proc + int LDB = 0; ///< leading dimension of B in each proc + int LDC = 0; ///< leading dimension of C, which can be C_local or C_global + private: + using resmem_dev_op = base_device::memory::resize_memory_op; + using delmem_dev_op = base_device::memory::delete_memory_op; + using syncmem_dev_op = base_device::memory::synchronize_memory_op; + using syncmem_d2h_op = base_device::memory::synchronize_memory_op; + using syncmem_h2d_op = base_device::memory::synchronize_memory_op; +}; +} // namespace ModuleBase +#endif \ No newline at end of file diff --git a/source/module_base/parallel_device.cpp b/source/module_base/parallel_device.cpp index 269a41821e..d7373674d6 100644 --- a/source/module_base/parallel_device.cpp +++ b/source/module_base/parallel_device.cpp @@ -2,6 +2,38 @@ #ifdef __MPI namespace Parallel_Common { +void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request) +{ + MPI_Isend(buf, count, MPI_DOUBLE, dest, tag, comm, request); +} +void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request) +{ + MPI_Isend(buf, count, MPI_DOUBLE_COMPLEX, dest, tag, comm, request); +} +void isend_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request) +{ + MPI_Isend(buf, count, MPI_FLOAT, dest, tag, comm, request); +} +void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request) +{ + MPI_Isend(buf, count, MPI_COMPLEX, dest, tag, comm, request); +} +void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status) +{ + MPI_Recv(buf, count, MPI_DOUBLE, source, tag, comm, status); +} +void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, 
MPI_Status* status) +{ + MPI_Recv(buf, count, MPI_DOUBLE_COMPLEX, source, tag, comm, status); +} +void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status) +{ + MPI_Recv(buf, count, MPI_FLOAT, source, tag, comm, status); +} +void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status) +{ + MPI_Recv(buf, count, MPI_COMPLEX, source, tag, comm, status); +} void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm) { MPI_Bcast(object, n * 2, MPI_DOUBLE, 0, comm); @@ -34,5 +66,95 @@ void reduce_data(float* object, const int& n, const MPI_Comm& comm) { MPI_Allreduce(MPI_IN_PLACE, object, n, MPI_FLOAT, MPI_SUM, comm); } +void gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + MPI_Allgatherv(sendbuf, sendcount, MPI_DOUBLE, recvbuf, recvcounts, displs, MPI_DOUBLE, comm); +} +void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + MPI_Allgatherv(sendbuf, sendcount, MPI_DOUBLE_COMPLEX, recvbuf, recvcounts, displs, MPI_DOUBLE_COMPLEX, comm); +} +void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + MPI_Allgatherv(sendbuf, sendcount, MPI_FLOAT, recvbuf, recvcounts, displs, MPI_FLOAT, comm); +} +void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm) +{ + MPI_Allgatherv(sendbuf, sendcount, MPI_COMPLEX, recvbuf, recvcounts, displs, MPI_COMPLEX, comm); } + +template +struct object_cpu_point +{ + bool alloc = false; + T* get(const T* object, const int& n, T* tmp_space = nullptr) + { + T* object_cpu = nullptr; + alloc = false; + + if (tmp_space == nullptr) + { + base_device::memory::resize_memory_op()(object_cpu, n); + alloc = true; + } + else + { + object_cpu = tmp_space; + } + base_device::memory::synchronize_memory_op()(object_cpu, + object, + n); + + return object_cpu; + } + void sync_h2d(T* object, const T* object_cpu, const int& n) + { + base_device::memory::synchronize_memory_op()(object, + object_cpu, + n); + } + void sync_d2h(T* object_cpu, const T* object, const int& n) + { + base_device::memory::synchronize_memory_op()(object_cpu, + object, + n); + } + void del(T* object_cpu) + { + if (alloc) + { + base_device::memory::delete_memory_op()(object_cpu); + } + } +}; + +template +struct object_cpu_point +{ + bool alloc = false; + T* get(const T* object, const int& n, T* tmp_space = nullptr) + { + return const_cast(object); + } + void sync_h2d(T* object, const T* object_cpu, const int& n) + { + } + void sync_d2h(T* object_cpu, const T* object, const int& n) + { + } + void del(T* object_cpu) + { + } +}; + +template struct object_cpu_point; +template struct object_cpu_point; +template struct object_cpu_point, base_device::DEVICE_CPU>; +template struct object_cpu_point, base_device::DEVICE_GPU>; +template struct object_cpu_point; +template struct object_cpu_point; +template struct object_cpu_point, base_device::DEVICE_CPU>; +template struct object_cpu_point, base_device::DEVICE_GPU>; + +} // namespace Parallel_Common #endif \ No newline at end of file diff --git a/source/module_base/parallel_device.h b/source/module_base/parallel_device.h index 7c41b8f28f..776de4e755 100644 --- a/source/module_base/parallel_device.h +++ b/source/module_base/parallel_device.h @@ -7,6 +7,14 @@ 
#include namespace Parallel_Common { +void isend_data(const double* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); +void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); +void isend_data(const float* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); +void isend_data(const std::complex* buf, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request); +void recv_data(double* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); +void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); +void recv_data(float* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); +void recv_data(std::complex* buf, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status); void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm); void bcast_data(std::complex* object, const int& n, const MPI_Comm& comm); void bcast_data(double* object, const int& n, const MPI_Comm& comm); @@ -15,6 +23,50 @@ void reduce_data(std::complex* object, const int& n, const MPI_Comm& com void reduce_data(std::complex* object, const int& n, const MPI_Comm& comm); void reduce_data(double* object, const int& n, const MPI_Comm& comm); void reduce_data(float* object, const int& n, const MPI_Comm& comm); +void gatherv_data(const double* sendbuf, int sendcount, double* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void gatherv_data(const float* sendbuf, int sendcount, float* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); +void gatherv_data(const std::complex* sendbuf, int sendcount, std::complex* recvbuf, const int* recvcounts, const int* displs, MPI_Comm& comm); + +template +struct object_cpu_point +{ + bool alloc = false; + T* get(const T* object, const int& n, T* tmp_space = nullptr); + void del(T* object); + void sync_d2h(T* object_cpu, const T* object, const int& n); + void sync_h2d(T* object, const T* object_cpu, const int& n); +}; + +/** + * @brief isend data in Device + * + */ +template +void isend_dev(const T* object, int count, int dest, int tag, MPI_Comm& comm, MPI_Request* request, T* tmp_space = nullptr) +{ + object_cpu_point o; + T* object_cpu = o.get(object, count, tmp_space); + o.sync_d2h(object_cpu, object, count); + isend_data(object_cpu, count, dest, tag, comm, request); + o.del(object_cpu); + return; +} + +/** + * @brief recv data in Device + * + */ +template +void recv_dev(T* object, int count, int source, int tag, MPI_Comm& comm, MPI_Status* status, T* tmp_space = nullptr) +{ + object_cpu_point o; + T* object_cpu = o.get(object, count, tmp_space); + recv_data(object_cpu, count, source, tag, comm, status); + o.sync_h2d(object, object_cpu, count); + o.del(object_cpu); + return; +} /** * @brief bcast data in Device @@ -28,79 +80,28 @@ void reduce_data(float* object, const int& n, const MPI_Comm& comm); * @param tmp_space tmp space in CPU */ template -void bcast_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) +void bcast_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { - const base_device::DEVICE_CPU* cpu_ctx = {}; - T* object_cpu = nullptr; - bool alloc = false; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) - { - 
if(tmp_space == nullptr) - { - base_device::memory::resize_memory_op()(object_cpu, n); - alloc = true; - } - else - { - object_cpu = tmp_space; - } - base_device::memory::synchronize_memory_op()(object_cpu, object, n); - } - else - { - object_cpu = object; - } - + object_cpu_point o; + T* object_cpu = o.get(object, n, tmp_space); + o.sync_d2h(object_cpu, object, n); bcast_data(object_cpu, n, comm); - - if (base_device::get_device_type(ctx) == base_device::GpuDevice) - { - base_device::memory::synchronize_memory_op()(object, object_cpu, n); - if(alloc) - { - base_device::memory::delete_memory_op()(object_cpu); - } - } + o.sync_h2d(object, object_cpu, n); + o.del(object_cpu); return; } template -void reduce_dev(const Device* ctx, T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) +void reduce_dev(T* object, const int& n, const MPI_Comm& comm, T* tmp_space = nullptr) { - const base_device::DEVICE_CPU* cpu_ctx = {}; - T* object_cpu = nullptr; - bool alloc = false; - if (base_device::get_device_type(ctx) == base_device::GpuDevice) - { - if(tmp_space == nullptr) - { - base_device::memory::resize_memory_op()(object_cpu, n); - alloc = true; - } - else - { - object_cpu = tmp_space; - } - base_device::memory::synchronize_memory_op()(object_cpu, object, n); - } - else - { - object_cpu = object; - } - + object_cpu_point o; + T* object_cpu = o.get(object, n, tmp_space); + o.sync_d2h(object_cpu, object, n); reduce_data(object_cpu, n, comm); - - if (base_device::get_device_type(ctx) == base_device::GpuDevice) - { - base_device::memory::synchronize_memory_op()(object, object_cpu, n); - if(alloc) - { - base_device::memory::delete_memory_op()(object_cpu); - } - } + o.sync_h2d(object, object_cpu, n); + o.del(object_cpu); return; } - } diff --git a/source/module_base/test_parallel/CMakeLists.txt b/source/module_base/test_parallel/CMakeLists.txt index f6a2c34c50..5132549f7a 100644 --- a/source/module_base/test_parallel/CMakeLists.txt +++ b/source/module_base/test_parallel/CMakeLists.txt @@ -34,6 +34,17 @@ add_test(NAME base_parallel_reduce_test WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} ) +AddTest( + TARGET base_para_gemm + LIBS MPI::MPI_CXX ${math_libs} base device parameter + SOURCES test_para_gemm.cpp +) + +add_test(NAME base_para_gemm_parallel + COMMAND mpirun -np 4 ./base_para_gemm + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} +) + AddTest( TARGET parallel_2d_test SOURCES parallel_2d_test.cpp ../parallel_2d.cpp diff --git a/source/module_base/test_parallel/test_para_gemm.cpp b/source/module_base/test_parallel/test_para_gemm.cpp new file mode 100644 index 0000000000..4b6445d057 --- /dev/null +++ b/source/module_base/test_parallel/test_para_gemm.cpp @@ -0,0 +1,466 @@ +#include "../kernels/math_kernel_op.h" +#include "../para_gemm.h" + +#include +#include +#include +#include + +void random_data(std::vector& A_global, + std::vector& B_global, + std::vector& Cref_global, + std::vector& C_global, + double& alpha, + double& beta) +{ + for (auto& val: A_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + for (auto& val: B_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + for (auto& val: Cref_global) + { + val = std::rand() / (RAND_MAX + 1.0); + } + C_global = Cref_global; + + alpha = std::rand() / (RAND_MAX + 1.0); + beta = std::rand() / (RAND_MAX + 1.0); +} +void random_data(std::vector>& A_global, + std::vector>& B_global, + std::vector>& Cref_global, + std::vector>& C_global, + std::complex& alpha, + std::complex& beta) +{ + for (auto& val: A_global) + { + val = 
std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + for (auto& val: B_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + for (auto& val: Cref_global) + { + val = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + } + C_global = Cref_global; + + alpha = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); + beta = std::complex(std::rand() / (RAND_MAX + 1.0), std::rand() / (RAND_MAX + 1.0)); +} +double get_double(std::complex& val) +{ + return val.real() + val.imag(); +} +double get_double(double& val) +{ + return val; +} + +void scatterv_data(const double* sendbuf, + const int* sendcounts, + const int* displs, + double* recvbuf, + const int recvcount, + MPI_Comm comm) +{ + MPI_Scatterv(sendbuf, sendcounts, displs, MPI_DOUBLE, recvbuf, recvcount, MPI_DOUBLE, 0, comm); +} +void scatterv_data(const std::complex* sendbuf, + const int* sendcounts, + const int* displs, + std::complex* recvbuf, + const int recvcount, + MPI_Comm comm) +{ + MPI_Scatterv(sendbuf, sendcounts, displs, MPI_DOUBLE_COMPLEX, recvbuf, recvcount, MPI_DOUBLE_COMPLEX, 0, comm); +} +template +class PgemmTest : public ::testing::Test +{ + protected: + void SetUp() override + { + MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_size(MPI_COMM_WORLD, &nproc); + } + void TearDown() override + { + MPI_Comm_free(&col_world); + MPI_Comm_free(&row_world); + } + + public: + void decide_ngroup(const int& willing_ncolgroup, const int& willing_nrowgroup) + { + ncolgroup = willing_ncolgroup; + nrowgroup = willing_nrowgroup; + if (nproc % (ncolgroup * nrowgroup) != 0) + { + ncolgroup = nproc; + nrowgroup = 1; + } + else + { + nrowgroup = nproc / ncolgroup; + } + + MPI_Comm_split(MPI_COMM_WORLD, rank % nrowgroup, rank / nrowgroup, &col_world); + MPI_Comm_split(MPI_COMM_WORLD, rank / nrowgroup, rank % nrowgroup, &row_world); + MPI_Comm_rank(col_world, &rank_col); + MPI_Comm_rank(row_world, &rank_row); + MPI_Comm_size(col_world, &nproc_col); + MPI_Comm_size(row_world, &nproc_row); + } + void randomize_initialization() + { + random_data(A_global, B_global, Cref_global, C_global, alpha, beta); + } + + void prepare(const int& ncolA_global, + const int& ncolB_global, + const int& nrow_global, + const int& LDA_global, + const int& LDB_global, + const int& LDC_global) + { + A_global = std::vector(LDA_global * ncolA_global, 0.0); + B_global = std::vector(LDB_global * ncolB_global, 0.0); + C_global = std::vector(LDC_global * ncolB_global, 0.0); + Cref_global = std::vector(LDC_global * ncolB_global, 0.0); + if (rank == 0) + { + + this->randomize_initialization(); + const base_device::DEVICE_CPU* ctx = {}; + char transC = 'C'; + char transN = 'N'; + ModuleBase::gemm_op()(ctx, + transC, + transN, + ncolA_global, + ncolB_global, + nrow_global, + &alpha, + A_global.data(), + LDA_global, + B_global.data(), + LDB_global, + &beta, + Cref_global.data(), + LDC_global); + } + + if (std::is_same::value) + { + MPI_Bcast(A_global.data(), A_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global.data(), B_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(C_global.data(), C_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(Cref_global.data(), Cref_global.size(), MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(&alpha, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); + MPI_Bcast(&beta, 1, MPI_DOUBLE, 0, MPI_COMM_WORLD); + } + else if (std::is_same>::value) + { + MPI_Bcast(A_global.data(), A_global.size(), 
MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(B_global.data(), B_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(C_global.data(), C_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(Cref_global.data(), Cref_global.size(), MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(&alpha, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + MPI_Bcast(&beta, 1, MPI_DOUBLE_COMPLEX, 0, MPI_COMM_WORLD); + } + + // Broadcast A_global and B_global to all ranks + getncol_and_row(ncolA_global, ncolB_global, nrow_global); + LDA = nrow + 1; + LDB = nrow + 2; + + A_local = std::vector(LDA * ncolA, 0.0); + B_local = std::vector(LDB * ncolB, 0.0); + + scatter_matrix(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global); + } + + void getncol_and_row(const int& ncolA_global, const int& ncolB_global, const int& nrow_global) + { + ncolA = ncolA_global / ncolgroup; + if (ncolA_global % ncolgroup > rank_col) + { + ncolA += 1; + } + ncolB = ncolB_global / ncolgroup; + if (ncolB_global % ncolgroup > rank_col) + { + ncolB += 1; + } + + nrow = nrow_global / nrowgroup; + if (nrow_global % nrowgroup > rank_row) + { + nrow += 1; + } + + ncolA_ip.resize(nproc_col); + ncolB_ip.resize(nproc_col); + nrow_ip.resize(nproc_row); + MPI_Allgather(&ncolA, 1, MPI_INT, ncolA_ip.data(), 1, MPI_INT, col_world); + MPI_Allgather(&ncolB, 1, MPI_INT, ncolB_ip.data(), 1, MPI_INT, col_world); + if (row_world != MPI_COMM_NULL) + { + MPI_Allgather(&nrow, 1, MPI_INT, nrow_ip.data(), 1, MPI_INT, row_world); + } + } + + void scatter_matrix(const int& ncolA_global, + const int& ncolB_global, + const int& nrow_global, + const int& LDA_global, + const int& LDB_global) + { + std::vector A_semiglobal(ncolA * LDA_global, 0.0); + std::vector B_semiglobal(ncolB * LDB_global, 0.0); + + // Scatter A_global and B_global to A_semiglobal and B_semiglobal + std::vector sendcounts(nproc_col, 0); + std::vector displs(nproc_col, 0); + for (int i = 0; i < nproc_col; i++) + { + sendcounts[i] = ncolA_ip[i] * LDA_global; + } + displs[0] = 0; + for (int i = 1; i < nproc_col; i++) + { + displs[i] = displs[i - 1] + sendcounts[i - 1]; + } + scatterv_data(A_global.data(), + sendcounts.data(), + displs.data(), + A_semiglobal.data(), + ncolA * LDA_global, + col_world); + + for (int i = 0; i < nproc_col; i++) + { + sendcounts[i] = ncolB_ip[i] * LDB_global; + } + displs[0] = 0; + for (int i = 1; i < nproc_col; i++) + { + displs[i] = displs[i - 1] + sendcounts[i - 1]; + } + scatterv_data(B_global.data(), + sendcounts.data(), + displs.data(), + B_semiglobal.data(), + ncolB * LDB_global, + col_world); + + // Scatter A_semiglobal and B_semiglobal to A_local and B_local + sendcounts.resize(nproc_row, 0); + displs.resize(nproc_row, 0); + for (int i = 0; i < nproc_row; i++) + { + sendcounts[i] = nrow_ip[i]; + } + displs[0] = 0; + for (int i = 1; i < nproc_row; i++) + { + displs[i] = displs[i - 1] + sendcounts[i - 1]; + } + for (int i = 0; i < ncolA; i++) + { + scatterv_data(A_semiglobal.data() + i * LDA_global, + sendcounts.data(), + displs.data(), + A_local.data() + i * LDA, + nrow, + row_world); + } + + for (int i = 0; i < ncolB; i++) + { + scatterv_data(B_semiglobal.data() + i * LDB_global, + sendcounts.data(), + displs.data(), + B_local.data() + i * LDB, + nrow, + row_world); + } + } + + void compare_result(const int& nrowC_global, const int& ncolC_global, const int& LDC_global) + { + for (int i = 0; i < ncolC_global; i++) + { + for (int j = 0; j < nrowC_global; j++) + { + EXPECT_NEAR(get_double(Cref_global[i * LDC_global + j]), + 
get_double(C_global[i * LDC_global + j]), + 1e-10); + } + } + } + + int rank = 0, nproc = 0; + T alpha = 0, beta = 0; + std::vector A_global, B_global, Cref_global, C_global; + std::vector A_local, B_local; + int ncolA = 0, ncolB = 0, nrow = 0, LDA = 0, LDB = 0; + int ncolgroup = 1, nrowgroup = 1; + int rank_col = 0, rank_row = 0; + int nproc_col = 0, nproc_row = 0; + ModuleBase::PGemmCN pgemm; + MPI_Comm col_world; + MPI_Comm row_world; + std::vector ncolA_ip, ncolB_ip, nrow_ip; +}; + +typedef ::testing::Types> MyTypes; + +TYPED_TEST_SUITE(PgemmTest, MyTypes); + +TYPED_TEST(PgemmTest, even_case) +{ + const int ncolA_global = 16, ncolB_global = 8, nrow_global = 12; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(2, 2); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +TYPED_TEST(PgemmTest, odd_case) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(2, 2); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +TYPED_TEST(PgemmTest, odd_case_not_gather) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(2, 2); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + std::vector colB_loc(this->nproc_col); + MPI_Allgather(&this->ncolB, 1, MPI_INT, colB_loc.data(), 1, MPI_INT, this->col_world); + std::vector displs(this->nproc_col); + displs[0] = 0; + for (int i = 1; i < this->nproc_col; i++) + { + displs[i] = (displs[i - 1] + colB_loc[i - 1]) * LDC_global; + } + int start = displs[this->rank_col]; + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global, + false); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()+ start); + + + + for (int i = 0; i < this->ncolB; i++) + { + for (int j = 0; j < ncolA_global; j++) + { + EXPECT_NEAR(get_double(this->Cref_global[i * LDC_global + start + j]), + get_double(this->C_global[i * LDC_global + start + j]), + 1e-10); + } + } +} + +TYPED_TEST(PgemmTest, row_parallel) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(1, 4); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, 
this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +TYPED_TEST(PgemmTest, col_parallel) +{ + const int ncolA_global = 17, ncolB_global = 7, nrow_global = 13; + const int LDA_global = 17, LDB_global = 18, LDC_global = 19; + + this->decide_ngroup(4, 1); + this->prepare(ncolA_global, ncolB_global, nrow_global, LDA_global, LDB_global, LDC_global); + + this->pgemm.set_dimension(this->col_world, + this->row_world, + this->ncolA, + this->LDA, + this->ncolB, + this->LDB, + this->nrow, + LDC_global); + this->pgemm.multiply(this->alpha, this->A_local.data(), this->B_local.data(), this->beta, this->C_global.data()); + + this->compare_result(ncolA_global, ncolB_global, LDC_global); +} + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + MPI_Init(&argc, &argv); + + int RANK, NPROC; + MPI_Comm_rank(MPI_COMM_WORLD, &RANK); + MPI_Comm_size(MPI_COMM_WORLD, &NPROC); + + int result = RUN_ALL_TESTS(); + + MPI_Finalize(); + return result; +} \ No newline at end of file diff --git a/source/module_elecstate/elecstate_pw.h b/source/module_elecstate/elecstate_pw.h index 8259d83024..679b9b712c 100644 --- a/source/module_elecstate/elecstate_pw.h +++ b/source/module_elecstate/elecstate_pw.h @@ -7,7 +7,7 @@ #include "module_basis/module_pw/pw_basis_k.h" #include "module_elecstate/kernels/elecstate_op.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/meta_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace elecstate { @@ -98,8 +98,8 @@ class ElecStatePW : public ElecState using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; - using gemv_op = hsolver::gemv_op; - using gemm_op = hsolver::gemm_op; + using gemv_op = ModuleBase::gemv_op; + using gemm_op = ModuleBase::gemm_op; }; } // namespace elecstate diff --git a/source/module_esolver/esolver_ks_lcaopw.cpp b/source/module_esolver/esolver_ks_lcaopw.cpp index 08d1043a4a..257a638e68 100644 --- a/source/module_esolver/esolver_ks_lcaopw.cpp +++ b/source/module_esolver/esolver_ks_lcaopw.cpp @@ -28,7 +28,7 @@ #include "module_hsolver/diago_iter_assist.h" #include "module_hsolver/hsolver_lcaopw.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_io/berryphase.h" #include "module_io/numerical_basis.h" #include "module_io/numerical_descriptor.h" diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index a96d487a5c..454303ba1d 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -26,7 +26,7 @@ #include "module_hsolver/diago_iter_assist.h" #include "module_hsolver/hsolver_pw.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_io/berryphase.h" #include "module_io/cube_io.h" #include "module_io/get_pchg_pw.h" @@ -73,7 +73,7 @@ ESolver_KS_PW::ESolver_KS_PW() #if ((defined __CUDA) || (defined __ROCM)) if (this->device == base_device::GpuDevice) { - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); hsolver::createGpuSolverHandle(); container::kernels::createGpuBlasHandle(); container::kernels::createGpuSolverHandle(); @@ -101,7 +101,7 @@ ESolver_KS_PW::~ESolver_KS_PW() if (this->device == base_device::GpuDevice) { #if defined(__CUDA) || 
defined(__ROCM) - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); hsolver::destroyGpuSolverHandle(); container::kernels::destroyGpuBlasHandle(); container::kernels::destroyGpuSolverHandle(); diff --git a/source/module_esolver/pw_others.cpp b/source/module_esolver/pw_others.cpp index ef32f041e8..0f2be0a998 100644 --- a/source/module_esolver/pw_others.cpp +++ b/source/module_esolver/pw_others.cpp @@ -30,7 +30,7 @@ #include "module_hsolver/diago_iter_assist.h" #include "module_hsolver/hsolver_pw.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_io/berryphase.h" #include "module_io/numerical_basis.h" #include "module_io/numerical_descriptor.h" diff --git a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp index 36baed7bab..8f77f275e3 100644 --- a/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp +++ b/source/module_hamilt_lcao/module_deltaspin/cal_mw_from_lambda.cpp @@ -5,7 +5,7 @@ #include "spin_constrain.h" #include "module_hamilt_pw/hamilt_pwdft/onsite_projector.h" #include "module_base/parallel_reduce.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_hsolver/hsolver_lcao.h" #include "module_hsolver/hsolver_pw.h" #include "module_elecstate/elecstate_pw.h" @@ -84,7 +84,7 @@ void spinconstrain::SpinConstrain>::calculate_delta_hcc(std { #if ((defined __CUDA) || (defined __ROCM)) base_device::DEVICE_GPU* ctx = {}; - hsolver::gemm_op, base_device::DEVICE_GPU>()( + ModuleBase::gemm_op, base_device::DEVICE_GPU>()( ctx, transa, transb, @@ -108,7 +108,7 @@ void spinconstrain::SpinConstrain>::calculate_delta_hcc(std else if (PARAM.inp.device == "cpu") { base_device::DEVICE_CPU* ctx = {}; - hsolver::gemm_op, base_device::DEVICE_CPU>()( + ModuleBase::gemm_op, base_device::DEVICE_CPU>()( ctx, transa, transb, diff --git a/source/module_hamilt_pw/hamilt_pwdft/forces.h b/source/module_hamilt_pw/hamilt_pwdft/forces.h index 695520dceb..5472396fcc 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/forces.h +++ b/source/module_hamilt_pw/hamilt_pwdft/forces.h @@ -11,7 +11,7 @@ #include "module_elecstate/elecstate.h" #include "module_hamilt_pw/hamilt_pwdft/VL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/force_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" #include "structure_factor.h" @@ -129,7 +129,7 @@ class Forces base_device::DEVICE_CPU* cpu_ctx = {}; base_device::AbacusDevice_t device = {}; private: - using gemm_op = hsolver::gemm_op, Device>; + using gemm_op = ModuleBase::gemm_op, Device>; using resmem_complex_op = base_device::memory::resize_memory_op, Device>; using resmem_complex_h_op = base_device::memory::resize_memory_op, base_device::DEVICE_CPU>; diff --git a/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp b/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp index 523cb2b504..b2dc156560 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp @@ -295,7 +295,7 @@ void FS_Nonlocal_tools::reduce_pool_becp(const int& npm) #ifdef __MPI if (GlobalV::NPROC_IN_POOL > 1) { - Parallel_Common::reduce_dev(this->ctx, this->becp, size_becp_act, POOL_WORLD); + Parallel_Common::reduce_dev, Device>(this->becp, 
size_becp_act, POOL_WORLD); } #endif } diff --git a/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.h b/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.h index 0cc640f27c..64a76e700d 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.h +++ b/source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.h @@ -7,7 +7,7 @@ #include "module_cell/unitcell.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/stress_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" #include @@ -215,7 +215,7 @@ class FS_Nonlocal_tools std::complex* becp = nullptr; // nbands * nkb /// @brief rename the operators for CPU/GPU device - using gemm_op = hsolver::gemm_op, Device>; + using gemm_op = ModuleBase::gemm_op, Device>; using cal_stress_nl_op = hamilt::cal_stress_nl_op; using cal_dbecp_noevc_nl_op = hamilt::cal_dbecp_noevc_nl_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h b/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h index f87dca7745..badeae0db6 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/hamilt_pw.h @@ -6,7 +6,7 @@ #include "module_elecstate/potentials/potential_new.h" #include "module_hamilt_general/hamilt.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace hamilt { @@ -44,8 +44,8 @@ class HamiltPW : public Hamilt T* qq_so = nullptr; Device* ctx = {}; - using gemv_op = hsolver::gemv_op; - using gemm_op = hsolver::gemm_op; + using gemv_op = ModuleBase::gemv_op; + using gemm_op = ModuleBase::gemm_op; using setmem_complex_op = base_device::memory::set_memory_op; using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/nonlocal_maths.hpp b/source/module_hamilt_pw/hamilt_pwdft/nonlocal_maths.hpp index 79649fab07..292ad80d43 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/nonlocal_maths.hpp +++ b/source/module_hamilt_pw/hamilt_pwdft/nonlocal_maths.hpp @@ -7,7 +7,7 @@ #include "module_cell/unitcell.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/stress_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace hamilt { diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_proj_tools.h b/source/module_hamilt_pw/hamilt_pwdft/onsite_proj_tools.h index 17c7e06491..0376a9709f 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_proj_tools.h +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_proj_tools.h @@ -7,7 +7,7 @@ #include "module_cell/unitcell.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/stress_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" #include @@ -178,7 +178,7 @@ class Onsite_Proj_tools std::complex* becp = nullptr; // nbands * nkb /// @brief rename the operators for CPU/GPU device - using gemm_op = hsolver::gemm_op, Device>; + using gemm_op = ModuleBase::gemm_op, Device>; using cal_stress_nl_op = hamilt::cal_stress_nl_op; using cal_dbecp_noevc_nl_op = hamilt::cal_dbecp_noevc_nl_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp 
b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp index f235df15e5..499cd4c837 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.cpp @@ -8,7 +8,7 @@ #include "module_base/projgen.h" #include "module_base/blas_connector.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #ifdef __MPI #include "module_base/parallel_reduce.h" #include "module_base/parallel_common.h" diff --git a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.h b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.h index a2bb99354b..b34d8291de 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.h +++ b/source/module_hamilt_pw/hamilt_pwdft/onsite_projector.h @@ -1,7 +1,7 @@ #ifndef MODULEHAMILTPW_ONSITEPROJECTOR_H #define MODULEHAMILTPW_ONSITEPROJECTOR_H #include "module_base/module_device/device.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_hamilt_pw/hamilt_pwdft/structure_factor.h" #include "module_basis/module_pw/pw_basis_k.h" #include "module_hamilt_pw/hamilt_pwdft/radial_proj.h" @@ -130,7 +130,7 @@ namespace projectors bool initialed = false; /// @brief rename the operators for CPU/GPU device - using gemm_op = hsolver::gemm_op, Device>; + using gemm_op = ModuleBase::gemm_op, Device>; using resmem_complex_op = base_device::memory::resize_memory_op, Device>; using resmem_complex_h_op = base_device::memory::resize_memory_op, base_device::DEVICE_CPU>; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h index 21fc574f5b..133eed1f5b 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.h @@ -5,7 +5,7 @@ #include "module_base/matrix.h" #include "module_basis/module_pw/pw_basis_k.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/meta_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include @@ -81,7 +81,7 @@ class Meta> : public OperatorPW base_device::DEVICE_CPU* cpu_ctx = {}; T *porter = nullptr; using meta_op = meta_pw_op; - using vector_mul_vector_op = hsolver::vector_mul_vector_op; + using vector_mul_vector_op = ModuleBase::vector_mul_vector_op; using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; using setmem_complex_op = base_device::memory::set_memory_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h index 91e760920a..31a98d24c9 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/nonlocal_pw.h @@ -5,7 +5,7 @@ #include "module_cell/unitcell.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/nonlocal_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" @@ -85,8 +85,8 @@ class Nonlocal> : public OperatorPW Real * deeq = nullptr; T * deeq_nc = nullptr; // using nonlocal_op = nonlocal_pw_op; - using gemv_op = hsolver::gemv_op; - using gemm_op = hsolver::gemm_op; + using gemv_op = ModuleBase::gemv_op; + using gemm_op = ModuleBase::gemm_op; using nonlocal_op = nonlocal_pw_op; using setmem_complex_op = 
base_device::memory::set_memory_op; using resmem_complex_op = base_device::memory::resize_memory_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/onsite_proj_pw.h b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/onsite_proj_pw.h index 975967d5c8..b28657d0df 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/onsite_proj_pw.h +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/onsite_proj_pw.h @@ -4,7 +4,7 @@ #include "operator_pw.h" #include "module_cell/unitcell.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace hamilt { @@ -76,8 +76,8 @@ class OnsiteProj> : public OperatorPW Device* ctx = {}; base_device::DEVICE_CPU* cpu_ctx = {}; - using gemv_op = hsolver::gemv_op; - using gemm_op = hsolver::gemm_op; + using gemv_op = ModuleBase::gemv_op; + using gemm_op = ModuleBase::gemm_op; using setmem_complex_op = base_device::memory::set_memory_op; using resmem_complex_op = base_device::memory::resize_memory_op; using delmem_complex_op = base_device::memory::delete_memory_op; diff --git a/source/module_hamilt_pw/hamilt_pwdft/stress_func.h b/source/module_hamilt_pw/hamilt_pwdft/stress_func.h index 878206ad38..20f6a91937 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/stress_func.h +++ b/source/module_hamilt_pw/hamilt_pwdft/stress_func.h @@ -14,7 +14,7 @@ #include "module_hamilt_pw/hamilt_pwdft/VNL_in_pw.h" #include "module_hamilt_pw/hamilt_pwdft/kernels/stress_op.h" #include "module_hamilt_pw/hamilt_pwdft/structure_factor.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_psi/psi.h" //------------------------------------------------------------------- @@ -241,7 +241,7 @@ class Stress_Func base_device::DEVICE_CPU* cpu_ctx = {}; base_device::AbacusDevice_t device = {}; private: - using gemm_op = hsolver::gemm_op, Device>; + using gemm_op = ModuleBase::gemm_op, Device>; using cal_stress_nl_op = hamilt::cal_stress_nl_op; using cal_dbecp_noevc_nl_op = hamilt::cal_dbecp_noevc_nl_op; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp index 34e20977eb..9facef1ddf 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_che.cpp @@ -1,7 +1,7 @@ #include "sto_che.h" #include "module_base/blas_connector.h" #include "module_base/module_device/device.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_base/module_container/ATen/kernels/blas.h" template diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_che.h b/source/module_hamilt_pw/hamilt_stodft/sto_che.h index f241553b66..578e5df0fb 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_che.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_che.h @@ -1,7 +1,7 @@ #ifndef STO_CHE_H #define STO_CHE_H #include "module_base/math_chebyshev.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_base/module_container/ATen/kernels/blas.h" template @@ -51,7 +51,7 @@ REAL vTMv(const REAL* v, const REAL* M, const int n) const REAL zero = 0; REAL* y = nullptr; base_device::memory::resize_memory_op()(y, n); - hsolver::gemv_op()(ctx, normal, n, n, &one, M, n, v, inc, &zero, y, inc); + ModuleBase::gemv_op()(ctx, normal, n, n, &one, M, n, v, inc, &zero, y, inc); REAL result = 0; REAL* dot_device = nullptr; base_device::memory::resize_memory_op()(dot_device, 1); 
diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index 8ec669febd..bd029a401d 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -7,7 +7,7 @@ #include "module_elecstate/occupy.h" #include "module_hamilt_pw/hamilt_pwdft/global.h" #include "module_parameter/parameter.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_elecstate/kernels/elecstate_op.h" template @@ -78,7 +78,7 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, char transN = 'N'; // sum(b - hsolver::gemm_op()(ctx, + ModuleBase::gemm_op()(ctx, transC, transN, PARAM.inp.nbands, @@ -95,7 +95,7 @@ void Stochastic_Iter::orthog(const int& ik, psi::Psi& psi, Parallel_Reduce::reduce_pool(sum, PARAM.inp.nbands * nchipk); // psi -= psi * sum - hsolver::gemm_op()(ctx, + ModuleBase::gemm_op()(ctx, transN, transN, npw, @@ -406,7 +406,7 @@ void Stochastic_Iter::calPn(const int& ik, Stochastic_WF& const int N = norder; const Real kweight = this->pkv->wk[ik]; - hsolver::gemm_op()(this->ctx, trans, normal, N, N, M, &kweight, vec_all, LDA, vec_all, LDA, &one, spolyv, N); + ModuleBase::gemm_op()(this->ctx, trans, normal, N, N, M, &kweight, vec_all, LDA, vec_all, LDA, &one, spolyv, N); // dgemm_(&trans, &normal, &N, &N, &M, &kweight, vec_all, &LDA, vec_all, &LDA, &one, spolyv, &N); } ModuleBase::timer::tick("Stochastic_Iter", "calPn"); diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.h b/source/module_hamilt_pw/hamilt_stodft/sto_iter.h index 901b1311f3..9953cfcd3b 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.h +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.h @@ -163,7 +163,7 @@ class Stochastic_Iter using delmem_complex_op = base_device::memory::delete_memory_op; using castmem_d2z_op = base_device::memory::cast_memory_op; using castmem_var_d2h_op = base_device::memory::cast_memory_op; - using gemv_op = hsolver::gemv_op; + using gemv_op = ModuleBase::gemv_op; }; #endif // Eelectrons_Iter diff --git a/source/module_hsolver/CMakeLists.txt b/source/module_hsolver/CMakeLists.txt index 93a708f21d..7f6c8ca4c6 100644 --- a/source/module_hsolver/CMakeLists.txt +++ b/source/module_hsolver/CMakeLists.txt @@ -36,7 +36,6 @@ if(ENABLE_LCAO) if(USE_CUDA) list(APPEND objects - ./kernels/math_kernel_op.cpp ./kernels/dngvd_op.cpp ./kernels/cuda/diag_cusolver.cu diago_cusolver.cpp diff --git a/source/module_hsolver/diago_bpcg.cpp b/source/module_hsolver/diago_bpcg.cpp index 846bef9ff8..36f77d372d 100644 --- a/source/module_hsolver/diago_bpcg.cpp +++ b/source/module_hsolver/diago_bpcg.cpp @@ -10,7 +10,7 @@ #include "diago_iter_assist.h" #include "module_base/blas_connector.h" #include "module_base/global_function.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" namespace hsolver { diff --git a/source/module_hsolver/diago_bpcg.h b/source/module_hsolver/diago_bpcg.h index a80c1406b6..90907de5e9 100644 --- a/source/module_hsolver/diago_bpcg.h +++ b/source/module_hsolver/diago_bpcg.h @@ -7,7 +7,7 @@ #include "module_base/module_device/types.h" #include "module_base/module_device/memory_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_hsolver/kernels/dngvd_op.h" #include @@ -343,9 +343,9 @@ class DiagoBPCG // note: these operators use template parameter base_device::Device_* // defined in 
module_base/module_device/types.h // different from ct_Device! - using calc_grad_with_block_op = hsolver::calc_grad_with_block_op; - using line_minimize_with_block_op = hsolver::line_minimize_with_block_op; - using gemm_op = hsolver::gemm_op; + using calc_grad_with_block_op = ModuleBase::calc_grad_with_block_op; + using line_minimize_with_block_op = ModuleBase::line_minimize_with_block_op; + using gemm_op = ModuleBase::gemm_op; }; diff --git a/source/module_hsolver/diago_cg.cpp b/source/module_hsolver/diago_cg.cpp index 29bdffa977..ea872d6d3e 100644 --- a/source/module_hsolver/diago_cg.cpp +++ b/source/module_hsolver/diago_cg.cpp @@ -226,14 +226,14 @@ void DiagoCG::calc_grad(const ct::Tensor& prec, // } // denghui replace this at 20221106 // TODO: use GPU precondition to initialize CG class - vector_div_vector_op()(ctx_, this->n_basis_, grad.data(), hphi.data(), prec.data()); - vector_div_vector_op()(ctx_, this->n_basis_, pphi.data(), sphi.data(), prec.data()); + ModuleBase::vector_div_vector_op()(ctx_, this->n_basis_, grad.data(), hphi.data(), prec.data()); + ModuleBase::vector_div_vector_op()(ctx_, this->n_basis_, pphi.data(), sphi.data(), prec.data()); // Update lambda ! // (4) - const Real eh = hsolver::dot_real_op()(ctx_, this->n_basis_, sphi.data(), grad.data()); + const Real eh = ModuleBase::dot_real_op()(ctx_, this->n_basis_, sphi.data(), grad.data()); // (5) - const Real es = hsolver::dot_real_op()(ctx_, this->n_basis_, sphi.data(), pphi.data()); + const Real es = ModuleBase::dot_real_op()(ctx_, this->n_basis_, sphi.data(), pphi.data()); const Real lambda = eh / es; // Update g! @@ -247,13 +247,13 @@ void DiagoCG::calc_grad(const ct::Tensor& prec, // grad.data()[i] -= lambda * this->pphi[i]; // } // haozhihan replace this 2022-10-6 - constantvector_addORsub_constantVector_op()(ctx_, - this->n_basis_, - grad.data(), - grad.data(), - 1.0, - pphi.data(), - (-lambda)); + ModuleBase::constantvector_addORsub_constantVector_op()(ctx_, + this->n_basis_, + grad.data(), + grad.data(), + 1.0, + pphi.data(), + (-lambda)); } template @@ -264,49 +264,49 @@ void DiagoCG::orth_grad(const ct::Tensor& psi, ct::Tensor& lagrange) { this->spsi_func_(grad, scg); // scg = S|grad> - gemv_op()(ctx_, - 'C', - this->n_basis_, - m, - this->one_, - psi.data(), - this->n_basis_, - scg.data(), - 1, - this->zero_, - lagrange.data(), - 1); + ModuleBase::gemv_op()(ctx_, + 'C', + this->n_basis_, + m, + this->one_, + psi.data(), + this->n_basis_, + scg.data(), + 1, + this->zero_, + lagrange.data(), + 1); Parallel_Reduce::reduce_pool(lagrange.data(), m); // (3) orthogonal |g> and |scg> to all states (0~m-1) //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // haozhihan replace 2022-10-07 - gemv_op()(ctx_, - 'N', - this->n_basis_, - m, - this->neg_one_, - psi.data(), - this->n_basis_, - lagrange.data(), - 1, - this->one_, - grad.data(), - 1); - - gemv_op()(ctx_, - 'N', - this->n_basis_, - m, - this->neg_one_, - psi.data(), - this->n_basis_, - lagrange.data(), - 1, - this->one_, - scg.data(), - 1); + ModuleBase::gemv_op()(ctx_, + 'N', + this->n_basis_, + m, + this->neg_one_, + psi.data(), + this->n_basis_, + lagrange.data(), + 1, + this->one_, + grad.data(), + 1); + + ModuleBase::gemv_op()(ctx_, + 'N', + this->n_basis_, + m, + this->neg_one_, + psi.data(), + this->n_basis_, + lagrange.data(), + 1, + this->one_, + scg.data(), + 1); } template @@ -328,7 +328,7 @@ void DiagoCG::calc_gamma_cg(const int& iter, // gg_inter = // Attention : the 'g' in g0 is getted last time gg_inter - = 
hsolver::dot_real_op()(ctx_, this->n_basis_, grad.data(), g0.data()); // b means before + = ModuleBase::dot_real_op()(ctx_, this->n_basis_, grad.data(), g0.data()); // b means before } // (2) Update for g0! @@ -342,11 +342,11 @@ void DiagoCG::calc_gamma_cg(const int& iter, // } // denghui replace this 20221106 // TODO: use GPU precondition instead - vector_mul_vector_op()(ctx_, this->n_basis_, g0.data(), scg.data(), prec.data()); + ModuleBase::vector_mul_vector_op()(ctx_, this->n_basis_, g0.data(), scg.data(), prec.data()); // (3) Update gg_now! // gg_now = < g|P|scg > = < g|g0 > - const Real gg_now = hsolver::dot_real_op()(ctx_, this->n_basis_, grad.data(), g0.data()); + const Real gg_now = ModuleBase::dot_real_op()(ctx_, this->n_basis_, grad.data(), g0.data()); if (iter == 0) { @@ -370,13 +370,13 @@ void DiagoCG::calc_gamma_cg(const int& iter, // pcg[i] = gamma * pcg[i] + grad.data()[i]; // } // haozhihan replace this 2022-10-6 - constantvector_addORsub_constantVector_op()(ctx_, - this->n_basis_, - cg.data(), - cg.data(), - gamma, - grad.data(), - 1.0); + ModuleBase::constantvector_addORsub_constantVector_op()(ctx_, + this->n_basis_, + cg.data(), + cg.data(), + gamma, + grad.data(), + 1.0); const Real norma = gamma * cg_norm * sin(theta); T znorma = static_cast(norma * -1); @@ -388,7 +388,7 @@ void DiagoCG::calc_gamma_cg(const int& iter, { pcg[i] -= norma * pphi_m[i]; }*/ - axpy_op()(ctx_, this->n_basis_, &znorma, phi_m.data(), 1, cg.data(), 1); + ModuleBase::axpy_op()(ctx_, this->n_basis_, &znorma, phi_m.data(), 1, cg.data(), 1); } } @@ -404,15 +404,15 @@ bool DiagoCG::update_psi(const ct::Tensor& pphi, ct::Tensor& sphi, ct::Tensor& hphi) { - cg_norm = sqrt(hsolver::dot_real_op()(ctx_, this->n_basis_, cg.data(), scg.data())); + cg_norm = sqrt(ModuleBase::dot_real_op()(ctx_, this->n_basis_, cg.data(), scg.data())); if (cg_norm < 1.0e-10) return true; const Real a0 - = hsolver::dot_real_op()(ctx_, this->n_basis_, phi_m.data(), pphi.data()) * 2.0 / cg_norm; + = ModuleBase::dot_real_op()(ctx_, this->n_basis_, phi_m.data(), pphi.data()) * 2.0 / cg_norm; const Real b0 - = hsolver::dot_real_op()(ctx_, this->n_basis_, cg.data(), pphi.data()) / (cg_norm * cg_norm); + = ModuleBase::dot_real_op()(ctx_, this->n_basis_, cg.data(), pphi.data()) / (cg_norm * cg_norm); const Real e0 = eigen; theta = atan(a0 / (e0 - b0)) / 2.0; @@ -438,13 +438,13 @@ bool DiagoCG::update_psi(const ct::Tensor& pphi, // } // haozhihan replace this 2022-10-6 - constantvector_addORsub_constantVector_op()(ctx_, - this->n_basis_, - phi_m.data(), - phi_m.data(), - cost, - cg.data(), - sint_norm); + ModuleBase::constantvector_addORsub_constantVector_op()(ctx_, + this->n_basis_, + phi_m.data(), + phi_m.data(), + cost, + cg.data(), + sint_norm); if (std::abs(eigen - e0) < ethreshold) { @@ -460,20 +460,20 @@ bool DiagoCG::update_psi(const ct::Tensor& pphi, // } // haozhihan replace this 2022-10-6 - constantvector_addORsub_constantVector_op()(ctx_, - this->n_basis_, - sphi.data(), - sphi.data(), - cost, - scg.data(), - sint_norm); - constantvector_addORsub_constantVector_op()(ctx_, - this->n_basis_, - hphi.data(), - hphi.data(), - cost, - pphi.data(), - sint_norm); + ModuleBase::constantvector_addORsub_constantVector_op()(ctx_, + this->n_basis_, + sphi.data(), + sphi.data(), + cost, + scg.data(), + sint_norm); + ModuleBase::constantvector_addORsub_constantVector_op()(ctx_, + this->n_basis_, + hphi.data(), + hphi.data(), + cost, + pphi.data(), + sint_norm); return false; } } @@ -496,36 +496,36 @@ void DiagoCG::schmit_orth(const int& 
m, const ct::Tensor& psi, const //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // haozhihan replace 2022-10-6 int inc = 1; - gemv_op()(ctx_, - 'C', - this->n_basis_, - m + 1, - this->one_, - psi.data(), - this->n_basis_, - sphi.data(), - inc, - this->zero_, - lagrange_so.data(), - inc); + ModuleBase::gemv_op()(ctx_, + 'C', + this->n_basis_, + m + 1, + this->one_, + psi.data(), + this->n_basis_, + sphi.data(), + inc, + this->zero_, + lagrange_so.data(), + inc); // be careful , here reduce m+1 Parallel_Reduce::reduce_pool(lagrange_so.data(), m + 1); //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // haozhihan replace 2022-10-6 - gemv_op()(ctx_, - 'N', - this->n_basis_, - m, - this->neg_one_, - psi.data(), - this->n_basis_, - lagrange_so.data(), - inc, - this->one_, - phi_m.data(), - inc); + ModuleBase::gemv_op()(ctx_, + 'N', + this->n_basis_, + m, + this->neg_one_, + psi.data(), + this->n_basis_, + lagrange_so.data(), + inc, + this->one_, + phi_m.data(), + inc); //====================================================================== /*for (int j = 0; j < m; j++) @@ -563,7 +563,7 @@ void DiagoCG::schmit_orth(const int& m, const ct::Tensor& psi, const // { // pphi_m[ig] /= psi_norm; // } - vector_div_constant_op()(ctx_, this->n_basis_, phi_m.data(), phi_m.data(), psi_norm); + ModuleBase::vector_div_constant_op()(ctx_, this->n_basis_, phi_m.data(), phi_m.data(), psi_norm); // ModuleBase::timer::tick("DiagoCG","schmit_orth"); } diff --git a/source/module_hsolver/diago_cg.h b/source/module_hsolver/diago_cg.h index 2741df42d4..9d254ded18 100644 --- a/source/module_hsolver/diago_cg.h +++ b/source/module_hsolver/diago_cg.h @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include @@ -126,7 +126,7 @@ class DiagoCG final bool test_exit_cond(const int& ntry, const int& notconv) const; - using dot_real_op = hsolver::dot_real_op; + using dot_real_op = ModuleBase::dot_real_op; const T * one_ = nullptr, * zero_ = nullptr, * neg_one_ = nullptr; }; diff --git a/source/module_hsolver/diago_dav_subspace.cpp b/source/module_hsolver/diago_dav_subspace.cpp index f7daf229a2..177e68847c 100644 --- a/source/module_hsolver/diago_dav_subspace.cpp +++ b/source/module_hsolver/diago_dav_subspace.cpp @@ -5,7 +5,7 @@ #include "module_base/module_device/device.h" #include "module_base/timer.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_base/kernels/dsp/dsp_connector.h" #include "module_hsolver/diag_hs_para.h" @@ -191,24 +191,25 @@ int Diago_DavSubspace::diag_once(const HPsiFunc& hpsi_func, setmem_complex_op()(psi_in, 0, n_band * psi_in_dmax); #ifdef __DSP - gemm_op_mt() // In order to not coding another whole template, using this method to minimize the code change. + ModuleBase::gemm_op_mt() // In order to not coding another whole template, using this method to + // minimize the code change. 
#else - gemm_op() + ModuleBase::gemm_op() #endif - (this->ctx, - 'N', - 'N', - this->dim, - this->n_band, - nbase, - this->one, - this->psi_in_iter, - this->dim, - this->vcc, - this->nbase_x, - this->zero, - psi_in, - psi_in_dmax); + (this->ctx, + 'N', + 'N', + this->dim, + this->n_band, + nbase, + this->one, + this->psi_in_iter, + this->dim, + this->vcc, + this->nbase_x, + this->zero, + psi_in, + psi_in_dmax); if (!this->notconv || (dav_iter == this->iter_nmax)) { @@ -275,9 +276,9 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, } #ifdef __DSP - gemm_op_mt() + ModuleBase::gemm_op_mt() #else - gemm_op() + ModuleBase::gemm_op() #endif (this->ctx, 'N', @@ -308,11 +309,11 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, { syncmem_var_h2d_op()(e_temp_hd, e_temp_cpu.data(), nbase); } - vector_mul_vector_op()(this->ctx, - nbase, - vcc + m * this->nbase_x, - vcc + m * this->nbase_x, - e_temp_hd); + ModuleBase::vector_mul_vector_op()(this->ctx, + nbase, + vcc + m * this->nbase_x, + vcc + m * this->nbase_x, + e_temp_hd); } if(this->device == base_device::GpuDevice) { @@ -320,24 +321,24 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, } #ifdef __DSP - gemm_op_mt() + ModuleBase::gemm_op_mt() #else - gemm_op() -#endif - (this->ctx, - 'N', - 'N', - this->dim, - notconv, - nbase, - this->one, - psi_iter, - this->dim, - vcc, - this->nbase_x, - this->one, - psi_iter + nbase * this->dim, - this->dim); + ModuleBase::gemm_op() +#endif + (this->ctx, + 'N', + 'N', + this->dim, + notconv, + nbase, + this->one, + psi_iter, + this->dim, + vcc, + this->nbase_x, + this->one, + psi_iter + nbase * this->dim, + this->dim); // "precondition!!!" std::vector pre(this->dim, 0.0); @@ -353,20 +354,20 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, if (this->device == base_device::GpuDevice) { syncmem_var_h2d_op()(this->d_precondition, pre.data(), this->dim); - vector_div_vector_op()(this->ctx, - this->dim, - psi_iter + (nbase + m) * this->dim, - psi_iter + (nbase + m) * this->dim, - this->d_precondition); + ModuleBase::vector_div_vector_op()(this->ctx, + this->dim, + psi_iter + (nbase + m) * this->dim, + psi_iter + (nbase + m) * this->dim, + this->d_precondition); } else #endif { - vector_div_vector_op()(this->ctx, - this->dim, - psi_iter + (nbase + m) * this->dim, - psi_iter + (nbase + m) * this->dim, - pre.data()); + ModuleBase::vector_div_vector_op()(this->ctx, + this->dim, + psi_iter + (nbase + m) * this->dim, + psi_iter + (nbase + m) * this->dim, + pre.data()); } } @@ -374,19 +375,19 @@ void Diago_DavSubspace::cal_grad(const HPsiFunc& hpsi_func, std::vector psi_norm(notconv, 0.0); for (size_t i = 0; i < notconv; i++) { - psi_norm[i] = dot_real_op()(this->ctx, - this->dim, - psi_iter + (nbase + i) * this->dim, - psi_iter + (nbase + i) * this->dim, - true); + psi_norm[i] = ModuleBase::dot_real_op()(this->ctx, + this->dim, + psi_iter + (nbase + i) * this->dim, + psi_iter + (nbase + i) * this->dim, + true); assert(psi_norm[i] > 0.0); psi_norm[i] = sqrt(psi_norm[i]); - vector_div_constant_op()(this->ctx, - this->dim, - psi_iter + (nbase + i) * this->dim, - psi_iter + (nbase + i) * this->dim, - psi_norm[i]); + ModuleBase::vector_div_constant_op()(this->ctx, + this->dim, + psi_iter + (nbase + i) * this->dim, + psi_iter + (nbase + i) * this->dim, + psi_norm[i]); } // update hpsi[:, nbase:nbase+notconv] @@ -409,9 +410,9 @@ void Diago_DavSubspace::cal_elem(const int& dim, ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem"); #ifdef __DSP - gemm_op_mt() + 
ModuleBase::gemm_op_mt() #else - gemm_op() + ModuleBase::gemm_op() #endif (this->ctx, 'C', @@ -429,9 +430,9 @@ void Diago_DavSubspace::cal_elem(const int& dim, this->nbase_x); #ifdef __DSP - gemm_op_mt() + ModuleBase::gemm_op_mt() #else - gemm_op() + ModuleBase::gemm_op() #endif (this->ctx, 'C', @@ -691,9 +692,9 @@ void Diago_DavSubspace::refresh(const int& dim, ModuleBase::timer::tick("Diago_DavSubspace", "refresh"); #ifdef __DSP - gemm_op_mt() + ModuleBase::gemm_op_mt() #else - gemm_op() + ModuleBase::gemm_op() #endif (this->ctx, 'N', diff --git a/source/module_hsolver/diago_david.cpp b/source/module_hsolver/diago_david.cpp index 6afaf998b8..21865a4ed1 100644 --- a/source/module_hsolver/diago_david.cpp +++ b/source/module_hsolver/diago_david.cpp @@ -5,7 +5,7 @@ #include "module_base/module_device/device.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #ifdef USE_PAW #include "module_cell/module_paw/paw_cell.h" @@ -266,21 +266,20 @@ int DiagoDavid::diag_once(const HPsiFunc& hpsi_func, setmem_complex_op()(psi_in, 0, nband * ld_psi); //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - gemm_op()(this->ctx, - 'N', - 'N', - dim, // m: row of A,C - nband, // n: col of B,C - nbase, // k: col of A, row of B - this->one, - basis, // A dim * nbase - dim, - this->vcc, // B nbase * nband - nbase_x, - this->zero, - psi_in, // C dim * nband - ld_psi - ); + ModuleBase::gemm_op()(this->ctx, + 'N', + 'N', + dim, // m: row of A,C + nband, // n: col of B,C + nbase, // k: col of A, row of B + this->one, + basis, // A dim * nbase + dim, + this->vcc, // B nbase * nband + nbase_x, + this->zero, + psi_in, // C dim * nband + ld_psi); if (!this->notconv || (dav_iter == david_maxiter)) { @@ -378,20 +377,20 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, // basis[nbase] = hpsi * vc_ev_vector = hpsi*vcc // basis' = vc_ev_vector' * hpsi' // (dim, notconv) (dim, nbase) (nbase, notconv) - gemm_op()(this->ctx, - 'N', - 'N', - dim, // m: row of A,C - notconv, // n: col of B,C - nbase, // k: col of A, row of B - this->one, // alpha - hpsi, // A dim * nbase - dim, // LDA: if(N) max(1,m) if(T) max(1,k) - vc_ev_vector, // B nbase * notconv - nbase, // LDB: if(N) max(1,k) if(T) max(1,n) - this->zero, // belta - basis + dim*nbase, // C dim * notconv - dim // LDC: if(N) max(1, m) + ModuleBase::gemm_op()(this->ctx, + 'N', + 'N', + dim, // m: row of A,C + notconv, // n: col of B,C + nbase, // k: col of A, row of B + this->one, // alpha + hpsi, // A dim * nbase + dim, // LDA: if(N) max(1,m) if(T) max(1,k) + vc_ev_vector, // B nbase * notconv + nbase, // LDB: if(N) max(1,k) if(T) max(1,n) + this->zero, // belta + basis + dim * nbase, // C dim * notconv + dim // LDC: if(N) max(1, m) ); //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< @@ -417,21 +416,21 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, Real* e_temp_gpu = nullptr; resmem_var_op()(e_temp_gpu, nbase); syncmem_var_h2d_op()(e_temp_gpu, e_temp_cpu.data(), nbase); - vector_mul_vector_op()(this->ctx, - nbase, - vc_ev_vector + m * nbase, - vc_ev_vector + m * nbase, - e_temp_gpu); + ModuleBase::vector_mul_vector_op()(this->ctx, + nbase, + vc_ev_vector + m * nbase, + vc_ev_vector + m * nbase, + e_temp_gpu); delmem_var_op()(e_temp_gpu); #endif } else { - vector_mul_vector_op()(this->ctx, - nbase, - vc_ev_vector + m * nbase, - vc_ev_vector + m * nbase, - e_temp_cpu.data()); + ModuleBase::vector_mul_vector_op()(this->ctx, + 
nbase, + vc_ev_vector + m * nbase, + vc_ev_vector + m * nbase, + e_temp_cpu.data()); } } //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< @@ -441,20 +440,20 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, // = (H - lambda * S) * psi * vcc // = (H - lambda * S) * psi_new //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< - gemm_op()(this->ctx, - 'N', - 'N', - dim, // m: row of A,C - notconv, // n: col of B,C - nbase, // k: col of A, row of B - this->one, // alpha - spsi, // A - dim, // LDA: if(N) max(1,m) if(T) max(1,k) - vc_ev_vector, // B - nbase, // LDB: if(N) max(1,k) if(T) max(1,n) - this->one, // belta - basis + dim*nbase, // C dim * notconv - dim // LDC: if(N) max(1, m) + ModuleBase::gemm_op()(this->ctx, + 'N', + 'N', + dim, // m: row of A,C + notconv, // n: col of B,C + nbase, // k: col of A, row of B + this->one, // alpha + spsi, // A + dim, // LDA: if(N) max(1,m) if(T) max(1,k) + vc_ev_vector, // B + nbase, // LDB: if(N) max(1,k) if(T) max(1,n) + this->one, // belta + basis + dim * nbase, // C dim * notconv + dim // LDC: if(N) max(1, m) ); //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< @@ -469,20 +468,20 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, if (this->device == base_device::GpuDevice) { #if defined(__CUDA) || defined(__ROCM) - vector_div_vector_op()(this->ctx, - dim, - basis + dim*(nbase + m), - basis + dim*(nbase + m), - this->d_precondition); + ModuleBase::vector_div_vector_op()(this->ctx, + dim, + basis + dim * (nbase + m), + basis + dim * (nbase + m), + this->d_precondition); #endif } else { - vector_div_vector_op()(this->ctx, - dim, - basis + dim*(nbase + m), - basis + dim*(nbase + m), - this->precondition); + ModuleBase::vector_div_vector_op()(this->ctx, + dim, + basis + dim * (nbase + m), + basis + dim * (nbase + m), + this->precondition); } //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // for (int ig = 0; ig < dim; ig++) @@ -519,20 +518,20 @@ void DiagoDavid::cal_grad(const HPsiFunc& hpsi_func, // first nbase bands psi* dot notconv bands spsi to prepare lagrange_matrix // calculate the square matrix for future lagranges - gemm_op()(this->ctx, - 'C', - 'N', - nbase, // m: row of A,C - notconv, // n: col of B,C - dim, // k: col of A, row of B - this->one, // alpha - basis, // A - dim, // LDA: if(N) max(1,m) if(T) max(1,k) - &spsi[nbase * dim], // B - dim, // LDB: if(N) max(1,k) if(T) max(1,n) - this->zero, // belta - lagrange, // C - nbase + notconv // LDC: if(N) max(1, m) + ModuleBase::gemm_op()(this->ctx, + 'C', + 'N', + nbase, // m: row of A,C + notconv, // n: col of B,C + dim, // k: col of A, row of B + this->one, // alpha + basis, // A + dim, // LDA: if(N) max(1,m) if(T) max(1,k) + &spsi[nbase * dim], // B + dim, // LDB: if(N) max(1,k) if(T) max(1,n) + this->zero, // belta + lagrange, // C + nbase + notconv // LDC: if(N) max(1, m) ); for (int m = 0; m < notconv; m++) @@ -593,20 +592,20 @@ void DiagoDavid::cal_elem(const int& dim, ModuleBase::timer::tick("DiagoDavid", "cal_elem"); // hcc[nbase](notconv, nbase + notconv)= basis[nbase]' * hpsi - gemm_op()(this->ctx, - 'C', - 'N', - notconv, - nbase + notconv, - dim, - this->one, - basis + dim*nbase, // basis(:,nbase:) dim * notconv - dim, - hpsi, // dim * (nbase + notconv) - dim, - this->zero, - hcc + nbase, // notconv * (nbase + notconv) - nbase_x); + ModuleBase::gemm_op()(this->ctx, + 'C', + 'N', + notconv, + nbase + notconv, + dim, + this->one, + basis + dim * nbase, // basis(:,nbase:) dim * 
notconv + dim, + hpsi, // dim * (nbase + notconv) + dim, + this->zero, + hcc + nbase, // notconv * (nbase + notconv) + nbase_x); // scc[nbase] = basis[nbase]' * spsi // gemm_op()(this->ctx, // 'C', @@ -627,7 +626,7 @@ void DiagoDavid::cal_elem(const int& dim, #ifdef __MPI if (diag_comm.nproc > 1) { - matrixTranspose_op()(this->ctx, nbase_x, nbase_x, hcc, hcc); + ModuleBase::matrixTranspose_op()(this->ctx, nbase_x, nbase_x, hcc, hcc); // matrixTranspose_op()(this->ctx, nbase_x, nbase_x, scc, scc); auto* swap = new T[notconv * nbase_x]; @@ -657,7 +656,7 @@ void DiagoDavid::cal_elem(const int& dim, // Parallel_Reduce::reduce_complex_double_pool( hcc + nbase * nbase_x, notconv * nbase_x ); // Parallel_Reduce::reduce_complex_double_pool( scc + nbase * nbase_x, notconv * nbase_x ); - matrixTranspose_op()(this->ctx, nbase_x, nbase_x, hcc, hcc); + ModuleBase::matrixTranspose_op()(this->ctx, nbase_x, nbase_x, hcc, hcc); // matrixTranspose_op()(this->ctx, nbase_x, nbase_x, scc, scc); } #endif @@ -751,39 +750,37 @@ void DiagoDavid::refresh(const int& dim, setmem_complex_op()(basis , 0, nbase_x * dim); // basis(dim, nband) = hpsi(dim, nbase) * vcc(nbase, nband) - gemm_op()(this->ctx, - 'N', - 'N', - dim, // m: row of A,C - nband, // n: col of B,C - nbase, // k: col of A, row of B - this->one, - hpsi, // A dim * nbase - dim, - vcc, // B nbase * nband - nbase_x, - zero, - basis, // C dim * nband - dim - ); + ModuleBase::gemm_op()(this->ctx, + 'N', + 'N', + dim, // m: row of A,C + nband, // n: col of B,C + nbase, // k: col of A, row of B + this->one, + hpsi, // A dim * nbase + dim, + vcc, // B nbase * nband + nbase_x, + zero, + basis, // C dim * nband + dim); //<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< // basis[nband] = spsi * vcc - gemm_op()(this->ctx, - 'N', - 'N', - dim, // m: row of A,C - nband, // n: col of B,C - nbase, // k: col of A, row of B - this->one, - spsi, // A dim * nbase - dim, - vcc, // B nbase * nband - nbase_x, - this->zero, - basis + dim*nband, // C dim * nband - dim - ); + ModuleBase::gemm_op()(this->ctx, + 'N', + 'N', + dim, // m: row of A,C + nband, // n: col of B,C + nbase, // k: col of A, row of B + this->one, + spsi, // A dim * nbase + dim, + vcc, // B nbase * nband + nbase_x, + this->zero, + basis + dim * nband, // C dim * nband + dim); // hpsi = basis, spsi = basis[nband] syncmem_complex_op()(hpsi, basis, dim * nband); @@ -900,37 +897,37 @@ void DiagoDavid::SchmidtOrth(const int& dim, { // lagrange_m[m - mv_size + 1 - mm_size] // = basis[m - mv_size + 1 - mm_size]' * spsi[m] - gemm_op()(this->ctx, - 'C', - 'N', - mm_size, // m: row of A,C - mm_size, // n: col of B,C - dim, // k: col of A, row of B - this->one, // alpha - basis + dim*(m - mv_size + 1 - mm_size), // A - dim, // LDA: if(N) max(1,m) if(T) max(1,k) - &spsi[m * dim], // B - dim, // LDB: if(N) max(1,k) if(T) max(1,n) - this->zero, // belta - &lagrange_m[m - mv_size + 1 - mm_size], // C - nband // LDC: if(N) max(1, m) + ModuleBase::gemm_op()(this->ctx, + 'C', + 'N', + mm_size, // m: row of A,C + mm_size, // n: col of B,C + dim, // k: col of A, row of B + this->one, // alpha + basis + dim * (m - mv_size + 1 - mm_size), // A + dim, // LDA: if(N) max(1,m) if(T) max(1,k) + &spsi[m * dim], // B + dim, // LDB: if(N) max(1,k) if(T) max(1,n) + this->zero, // belta + &lagrange_m[m - mv_size + 1 - mm_size], // C + nband // LDC: if(N) max(1, m) ); } // calculate other lagranges for this band // lagrange_m[m - mv_size + 1] // = basis[m - mv_size + 1]' * spsi[m] - gemv_op()(this->ctx, - 'C', - dim, - 
mv_size,
-                  this->one,
-                  basis + dim*(m - mv_size + 1),
-                  dim,
-                  &spsi[m * dim],
-                  1,
-                  this->zero,
-                  &lagrange_m[m - mv_size + 1],
-                  1);
+        ModuleBase::gemv_op()(this->ctx,
+                              'C',
+                              dim,
+                              mv_size,
+                              this->one,
+                              basis + dim * (m - mv_size + 1),
+                              dim,
+                              &spsi[m * dim],
+                              1,
+                              this->zero,
+                              &lagrange_m[m - mv_size + 1],
+                              1);
         Parallel_Reduce::reduce_pool(lagrange_m, m + 1);
@@ -942,21 +939,21 @@ void DiagoDavid::SchmidtOrth(const int& dim,
             // / psi_m = psi_m - \sum_{i < m} \langle psi(i)|S|psi(m) \rangle psi(i)
             // psi_m = psi_m - basis * lagrange_m
-            gemv_op()(this->ctx,
-                      'N',
-                      dim,
-                      m,
-                      this->neg_one,
-                      basis,
-                      dim,
-                      lagrange_m,
-                      1,
-                      this->one,
-                      psi_m,
-                      1);
+            ModuleBase::gemv_op()(this->ctx,
+                                  'N',
+                                  dim,
+                                  m,
+                                  this->neg_one,
+                                  basis,
+                                  dim,
+                                  lagrange_m,
+                                  1,
+                                  this->one,
+                                  psi_m,
+                                  1);
             // psi_norm = psi_norm - lagrange_m · lagrange_m
-            psi_norm -= dot_real_op()(this->ctx, m, lagrange_m, lagrange_m, false);
+            psi_norm -= ModuleBase::dot_real_op()(this->ctx, m, lagrange_m, lagrange_m, false);
             // for (int j = 0; j < m; j++)
             // {
@@ -983,7 +980,7 @@ void DiagoDavid::SchmidtOrth(const int& dim,
     else
     {
         // psi_m = psi_m / psi_norm
-        vector_div_constant_op()(this->ctx, dim, psi_m, psi_m, psi_norm);
+        ModuleBase::vector_div_constant_op()(this->ctx, dim, psi_m, psi_m, psi_norm);
         // for (int i = 0; i < npw; i++)
         // {
         //     psi_m[i] /= psi_norm;
diff --git a/source/module_hsolver/diago_iter_assist.cpp b/source/module_hsolver/diago_iter_assist.cpp
index 5a3acf8e53..ea1f36d900 100644
--- a/source/module_hsolver/diago_iter_assist.cpp
+++ b/source/module_hsolver/diago_iter_assist.cpp
@@ -9,7 +9,7 @@
 #include "module_base/parallel_reduce.h"
 #include "module_base/timer.h"
 #include "module_hsolver/kernels/dngvd_op.h"
-#include "module_hsolver/kernels/math_kernel_op.h"
+#include "module_base/kernels/math_kernel_op.h"
 namespace hsolver
 {
@@ -73,39 +73,39 @@ void DiagoIterAssist::diagH_subspace(const hamilt::Hamilt*
         hpsi_info hpsi_in(&psi, all_bands_range, hphi);
         pHamilt->ops->hPsi(hpsi_in);
-        gemm_op()(ctx,
-                  'C',
-                  'N',
-                  nstart,
-                  nstart,
-                  dmin,
-                  &one,
-                  psi.get_pointer(),
-                  dmax,
-                  hphi,
-                  dmax,
-                  &zero,
-                  hcc,
-                  nstart);
+        ModuleBase::gemm_op()(ctx,
+                              'C',
+                              'N',
+                              nstart,
+                              nstart,
+                              dmin,
+                              &one,
+                              psi.get_pointer(),
+                              dmax,
+                              hphi,
+                              dmax,
+                              &zero,
+                              hcc,
+                              nstart);
         T* sphi = temp;
         // do sPsi for all bands
         pHamilt->sPsi(psi.get_pointer(), sphi, dmax, dmin, nstart);
-        gemm_op()(ctx,
-                  'C',
-                  'N',
-                  nstart,
-                  nstart,
-                  dmin,
-                  &one,
-                  psi.get_pointer(),
-                  dmax,
-                  sphi,
-                  dmax,
-                  &zero,
-                  scc,
-                  nstart);
+        ModuleBase::gemm_op()(ctx,
+                              'C',
+                              'N',
+                              nstart,
+                              nstart,
+                              dmin,
+                              &one,
+                              psi.get_pointer(),
+                              dmax,
+                              sphi,
+                              dmax,
+                              &zero,
+                              scc,
+                              nstart);
     }
     if (GlobalV::NPROC_IN_POOL > 1)
@@ -121,25 +121,25 @@ void DiagoIterAssist::diagH_subspace(const hamilt::Hamilt*
     const int ld_temp = in_place ?
dmax : dmin; { // code block to calculate evc - gemm_op()(ctx, - 'N', - 'N', - dmin, - n_band, - nstart, - &one, - psi.get_pointer(), // dmin * nstart - dmax, - vcc, // nstart * n_band - nstart, - &zero, - temp, - ld_temp); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + dmin, + n_band, + nstart, + &one, + psi.get_pointer(), // dmin * nstart + dmax, + vcc, // nstart * n_band + nstart, + &zero, + temp, + ld_temp); } if (!in_place) { - matrixSetToAnother()(ctx, n_band, temp, ld_temp, evc.get_pointer(), dmax); + ModuleBase::matrixSetToAnother()(ctx, n_band, temp, ld_temp, evc.get_pointer(), dmax); delmem_complex_op()(temp); } delmem_complex_op()(hcc); @@ -222,7 +222,7 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* pHamilt->ops->hPsi(hpsi_in); // calculate the related elements in hcc - gemv_op()(ctx, 'C', psi_nc, nstart, &one, psi, psi_nc, hpsi, 1, &zero, hcc + i * nstart, 1); + ModuleBase::gemv_op()(ctx, 'C', psi_nc, nstart, &one, psi, psi_nc, hpsi, 1, &zero, hcc + i * nstart, 1); } T* spsi = temp; @@ -232,18 +232,18 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* syncmem_complex_op()(ppsi, psi + i * psi_nc, psi_nc); pHamilt->sPsi(ppsi, spsi, dmin, dmin, 1); - gemv_op()(ctx, - 'C', - psi_nc, - nstart, - &one, - psi, - psi_nc, // nbasis - spsi, - 1, - &zero, - scc + i * nstart, - 1); + ModuleBase::gemv_op()(ctx, + 'C', + psi_nc, + nstart, + &one, + psi, + psi_nc, // nbasis + spsi, + 1, + &zero, + scc + i * nstart, + 1); } delmem_complex_op()(temp); } @@ -264,13 +264,13 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* hpsi_info hpsi_in(&psi_temp, all_bands_range, hpsi); pHamilt->ops->hPsi(hpsi_in); - gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, hpsi, dmax, &zero, hcc, nstart); + ModuleBase::gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, hpsi, dmax, &zero, hcc, nstart); T* spsi = temp; // do sPsi for all bands pHamilt->sPsi(ppsi, spsi, psi_temp.get_nbasis(), psi_temp.get_nbasis(), psi_temp.get_nbands()); - gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, spsi, dmax, &zero, scc, nstart); + ModuleBase::gemm_op()(ctx, 'C', 'N', nstart, nstart, dmin, &one, ppsi, dmax, spsi, dmax, &zero, scc, nstart); delmem_complex_op()(temp); add_to_hcc(hcc, nstart); @@ -315,20 +315,20 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* // because psi and evc are different here, // I think if psi and evc are the same, // there may be problems, mohan 2011-01-01 - gemm_op()(ctx, - 'N', - 'N', - dmax, - n_band, - nstart, - &one, - psi, // dmax * nstart - dmax, - vcc, // nstart * n_band - nstart, - &zero, - evc.get_pointer(), - dmax); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + dmax, + n_band, + nstart, + &one, + psi, // dmax * nstart + dmax, + vcc, // nstart * n_band + nstart, + &zero, + evc.get_pointer(), + dmax); } else { @@ -338,20 +338,20 @@ void DiagoIterAssist::diagH_subspace_init(hamilt::Hamilt* // resmem_complex_op()(ctx, evctemp, n_band * dmin, "DiagSub::evctemp"); // setmem_complex_op()(ctx, evctemp, 0, n_band * dmin); - gemm_op()(ctx, - 'N', - 'N', - dmin, - n_band, - nstart, - &one, - psi, // dmin * nstart - dmax, - vcc, // nstart * n_band - nstart, - &zero, - evc.get_pointer(), - dmax); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + dmin, + n_band, + nstart, + &one, + psi, // dmin * nstart + dmax, + vcc, // nstart * n_band + nstart, + &zero, + evc.get_pointer(), + dmax); // matrixSetToAnother()(ctx, n_band, evctemp, dmin, evc.get_pointer(), dmax); @@ -442,39 +442,39 @@ void 
DiagoIterAssist::cal_hs_subspace(const hamilt::Hamilt hpsi_info hpsi_in(&psi, all_bands_range, hphi); pHamilt->ops->hPsi(hpsi_in); - gemm_op()(ctx, - 'C', - 'N', - nstart, - nstart, - dmin, - &one, - psi.get_pointer(), - dmax, - hphi, - dmax, - &zero, - hcc, - nstart); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + nstart, + nstart, + dmin, + &one, + psi.get_pointer(), + dmax, + hphi, + dmax, + &zero, + hcc, + nstart); T* sphi = temp; // do sPsi for all bands pHamilt->sPsi(psi.get_pointer(), sphi, dmax, dmin, nstart); - gemm_op()(ctx, - 'C', - 'N', - nstart, - nstart, - dmin, - &one, - psi.get_pointer(), - dmax, - sphi, - dmax, - &zero, - scc, - nstart); + ModuleBase::gemm_op()(ctx, + 'C', + 'N', + nstart, + nstart, + dmin, + &one, + psi.get_pointer(), + dmax, + sphi, + dmax, + &zero, + scc, + nstart); } if (GlobalV::NPROC_IN_POOL > 1) @@ -509,20 +509,20 @@ void DiagoIterAssist::diag_responce( const T* hcc, DiagoIterAssist::diagH_LAPACK(nstart, nstart, hcc, scc, nstart, en, vcc); { // code block to calculate tar_mat - gemm_op()(ctx, - 'N', - 'N', - mat_col, - nstart, - nstart, - &one, - mat_in, // mat_col * nstart - mat_col, - vcc, // nstart * nstart - nstart, - &zero, - mat_out, - mat_col); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + mat_col, + nstart, + nstart, + &one, + mat_in, // mat_col * nstart + mat_col, + vcc, // nstart * nstart + nstart, + &zero, + mat_out, + mat_col); } delmem_complex_op()(vcc); @@ -557,21 +557,21 @@ void DiagoIterAssist::diag_subspace_psi(const T* hcc, T* temp = nullptr; resmem_complex_op()(temp, nstart * dmax, "DiagSub::temp"); setmem_complex_op()(temp, 0, nstart * dmax); - gemm_op()(ctx, - 'N', - 'N', - dmin, - n_band, - nstart, - &one, - evc.get_pointer(), // dmin * nstart - dmax, - vcc, // nstart * n_band - nstart, - &zero, - temp, - dmin); - matrixSetToAnother()(ctx, n_band, temp, dmin, evc.get_pointer(), dmax); + ModuleBase::gemm_op()(ctx, + 'N', + 'N', + dmin, + n_band, + nstart, + &one, + evc.get_pointer(), // dmin * nstart + dmax, + vcc, // nstart * n_band + nstart, + &zero, + temp, + dmin); + ModuleBase::matrixSetToAnother()(ctx, n_band, temp, dmin, evc.get_pointer(), dmax); delmem_complex_op()(temp); } diff --git a/source/module_hsolver/hsolver_pw_sdft.cpp b/source/module_hsolver/hsolver_pw_sdft.cpp index 68075fc111..d03b37b848 100644 --- a/source/module_hsolver/hsolver_pw_sdft.cpp +++ b/source/module_hsolver/hsolver_pw_sdft.cpp @@ -60,7 +60,7 @@ void HSolverPW_SDFT::solve(const UnitCell& ucell, #ifdef __MPI if (nbands > 0 && PARAM.inp.bndpar > 1) { - Parallel_Common::bcast_dev(this->ctx, &psi(ik, 0, 0), npwx * nbands, PARAPW_WORLD, &psi_cpu(ik, 0, 0)); + Parallel_Common::bcast_dev(&psi(ik, 0, 0), npwx * nbands, PARAPW_WORLD, &psi_cpu(ik, 0, 0)); MPI_Bcast(&pes->ekb(ik, 0), nbands, MPI_DOUBLE, 0, PARAPW_WORLD); } #endif diff --git a/source/module_hsolver/kernels/test/CMakeLists.txt b/source/module_hsolver/kernels/test/CMakeLists.txt index c8d1f2cdd9..5fe6bf4a24 100644 --- a/source/module_hsolver/kernels/test/CMakeLists.txt +++ b/source/module_hsolver/kernels/test/CMakeLists.txt @@ -5,13 +5,7 @@ if(USE_CUDA OR USE_ROCM) AddTest( TARGET Hsolver_Kernels_UTs LIBS parameter ${math_libs} base device - SOURCES math_kernel_test.cpp math_dngvd_test.cpp - ) -elseif() - AddTest( - TARGET Hsolver_Kernels_UTs - LIBS parameter ${math_libs} base device - SOURCES math_kernel_test.cpp ../../../module_base/blas_connector.cpp + SOURCES math_dngvd_test.cpp ) endif() diff --git a/source/module_hsolver/kernels/test/math_dngvd_test.cpp 
b/source/module_hsolver/kernels/test/math_dngvd_test.cpp index a67b18d4be..d8f2376890 100644 --- a/source/module_hsolver/kernels/test/math_dngvd_test.cpp +++ b/source/module_hsolver/kernels/test/math_dngvd_test.cpp @@ -2,7 +2,7 @@ #include "module_base/lapack_connector.h" #include "module_base/module_device/memory_op.h" #include "module_hsolver/kernels/dngvd_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include #include @@ -144,13 +144,13 @@ TEST_F(TestModuleHsolverMathDngvd, transpose_gpu) synchronize_memory_op_C2G_Z()(device_transpose, transpose.data(), transpose.size()); // run - hsolver::createGpuBlasHandle(); - hsolver::matrixTranspose_op, base_device::DEVICE_GPU>()(gpu_ctx, + ModuleBase::createGpuBlasHandle(); + ModuleBase::matrixTranspose_op, base_device::DEVICE_GPU>()(gpu_ctx, 2, 3, device_transpose, device_transpose); - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); // copy transpose data from GPU to CPU std::vector> transpose_result = { diff --git a/source/module_hsolver/kernels/test/perf_math_kernel.cpp b/source/module_hsolver/kernels/test/perf_math_kernel.cpp index b2b0704a9d..e0a955ccb5 100644 --- a/source/module_hsolver/kernels/test/perf_math_kernel.cpp +++ b/source/module_hsolver/kernels/test/perf_math_kernel.cpp @@ -1,7 +1,7 @@ #include "module_base/blas_connector.h" #include "module_base/constants.h" #include "module_base/module_device/memory_op.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include #include @@ -114,7 +114,7 @@ class PerfModuleHsolverMathKernel : public benchmark::Fixture { resize_memory_op_double()(test_dvector_a_gpu, dim_vector); synchronize_memory_op_double()(test_dvector_a_gpu, test_dvector_a, dim_vector); - hsolver::createGpuBlasHandle(); + ModuleBase::createGpuBlasHandle(); #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM @@ -125,36 +125,36 @@ class PerfModuleHsolverMathKernel : public benchmark::Fixture { delete[] result_zvector; delete[] test_dvector_a; #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM - hsolver::destoryBLAShandle(); + ModuleBase::destoryBLAShandle(); #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM } // OPs need benchmark // CPU operator - using zdot_real_cpu_op = hsolver::dot_real_op, base_device::DEVICE_CPU>; + using zdot_real_cpu_op = ModuleBase::dot_real_op, base_device::DEVICE_CPU>; - using vector_div_constant_op_cpu = hsolver::vector_div_constant_op, base_device::DEVICE_CPU>; - using vector_mul_vector_op_cpu = hsolver::vector_mul_vector_op, base_device::DEVICE_CPU>; - using vector_div_vector_op_cpu = hsolver::vector_div_vector_op, base_device::DEVICE_CPU>; + using vector_div_constant_op_cpu = ModuleBase::vector_div_constant_op, base_device::DEVICE_CPU>; + using vector_mul_vector_op_cpu = ModuleBase::vector_mul_vector_op, base_device::DEVICE_CPU>; + using vector_div_vector_op_cpu = ModuleBase::vector_div_vector_op, base_device::DEVICE_CPU>; using constantvector_addORsub_constantVector_op_cpu - = hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_CPU>; - using axpy_op_cpu = hsolver::axpy_op, base_device::DEVICE_CPU>; - using scal_op_cpu = hsolver::scal_op; - using gemv_op_cpu = hsolver::gemv_op, base_device::DEVICE_CPU>; + = ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_CPU>; + using axpy_op_cpu = ModuleBase::axpy_op, base_device::DEVICE_CPU>; + using scal_op_cpu = ModuleBase::scal_op; + using gemv_op_cpu = 
ModuleBase::gemv_op, base_device::DEVICE_CPU>; #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM // GPU operator - using zdot_real_gpu_op = hsolver::dot_real_op, base_device::DEVICE_GPU>; + using zdot_real_gpu_op = ModuleBase::dot_real_op, base_device::DEVICE_GPU>; - using vector_div_constant_op_gpu = hsolver::vector_div_constant_op, base_device::DEVICE_GPU>; - using vector_mul_vector_op_gpu = hsolver::vector_mul_vector_op, base_device::DEVICE_GPU>; - using vector_div_vector_op_gpu = hsolver::vector_div_vector_op, base_device::DEVICE_GPU>; + using vector_div_constant_op_gpu = ModuleBase::vector_div_constant_op, base_device::DEVICE_GPU>; + using vector_mul_vector_op_gpu = ModuleBase::vector_mul_vector_op, base_device::DEVICE_GPU>; + using vector_div_vector_op_gpu = ModuleBase::vector_div_vector_op, base_device::DEVICE_GPU>; using constantvector_addORsub_constantVector_op_gpu - = hsolver::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>; - using axpy_op_gpu = hsolver::axpy_op, base_device::DEVICE_GPU>; - using scal_op_gpu = hsolver::scal_op; + = ModuleBase::constantvector_addORsub_constantVector_op, base_device::DEVICE_GPU>; + using axpy_op_gpu = ModuleBase::axpy_op, base_device::DEVICE_GPU>; + using scal_op_gpu = ModuleBase::scal_op; #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM }; diff --git a/source/module_hsolver/test/CMakeLists.txt b/source/module_hsolver/test/CMakeLists.txt index fdb447a09d..e44171912c 100644 --- a/source/module_hsolver/test/CMakeLists.txt +++ b/source/module_hsolver/test/CMakeLists.txt @@ -114,7 +114,6 @@ if (ENABLE_MPI) TARGET HSolver_LCAO_cusolver LIBS parameter ${math_libs} base psi device SOURCES diago_lcao_cusolver_test.cpp ../diago_cusolver.cpp ../diago_scalapack.cpp - ../kernels/math_kernel_op.cpp ../kernels/dngvd_op.cpp ../kernels/cuda/diag_cusolver.cu ) diff --git a/source/module_hsolver/test/diago_bpcg_test.cpp b/source/module_hsolver/test/diago_bpcg_test.cpp index 8978334106..e6af8b5b5e 100644 --- a/source/module_hsolver/test/diago_bpcg_test.cpp +++ b/source/module_hsolver/test/diago_bpcg_test.cpp @@ -144,7 +144,7 @@ class DiagoBPCGPrepare base_device::DEVICE_CPU *ctx = {}; // hpsi_out(dim * nvec) = h_mat(dim * dim) * psi_in(dim * nvec) - hsolver::gemm_op()( + ModuleBase::gemm_op()( ctx, 'N', 'N', dim, nvec, dim, one_, diff --git a/source/module_lr/operator_casida/operator_lr_diag.h b/source/module_lr/operator_casida/operator_lr_diag.h index 99a61d90df..a739b81991 100644 --- a/source/module_lr/operator_casida/operator_lr_diag.h +++ b/source/module_lr/operator_casida/operator_lr_diag.h @@ -1,6 +1,6 @@ #pragma once #include "module_lr/utils/lr_util.h" -#include "module_hsolver/kernels/math_kernel_op.h" +#include "module_base/kernels/math_kernel_op.h" #include "module_hamilt_general/operator.h" #ifdef __MPI #include "module_base/parallel_common.h" @@ -46,7 +46,7 @@ namespace LR const bool is_first_node = false)const override { ModuleBase::TITLE("OperatorLRDiag", "act"); - hsolver::vector_mul_vector_op()(this->ctx, + ModuleBase::vector_mul_vector_op()(this->ctx, nk * pX.get_local_size(), // local size of particle-hole basis hpsi, psi_in,