diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp
index 69f51b744f..85ea4584e9 100644
--- a/source/module_base/blas_connector.cpp
+++ b/source/module_base/blas_connector.cpp
@@ -82,6 +82,7 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
 }
 
 // C = a * A.? * B.? + b * C
+// Row-Major part
 void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
     const float alpha, const float *a, const int lda, const float *b, const int ldb,
     const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
@@ -154,6 +155,147 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
 #endif
 }
 
+// Col-Major part: every argument is forwarded to the Fortran BLAS routine in caller order
+void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
+    const float alpha, const float *a, const int lda, const float *b, const int ldb,
+    const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
+        sgemm_(&transa, &transb, &m, &n, &k,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc);
+    }
+    #ifdef __DSP
+    else if (device_type == base_device::AbacusDevice_t::DspDevice){
+        sgemm_mth_(&transa, &transb, &m, &n, &k, // flags stay in caller order, matching a/b and the d/c/z overloads
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc, GlobalV::MY_RANK);
+    }
+    #endif
+}
+
+void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
+    const double alpha, const double *a, const int lda, const double *b, const int ldb,
+    const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
+        dgemm_(&transa, &transb, &m, &n, &k,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc);
+    }
+    #ifdef __DSP
+    else if (device_type == base_device::AbacusDevice_t::DspDevice){
+        dgemm_mth_(&transa, &transb, &m, &n, &k,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc, GlobalV::MY_RANK);
+    }
+    #endif
+}
+
+void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
+    const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
+    const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
+        cgemm_(&transa, &transb, &m, &n, &k,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc);
+    }
+    #ifdef __DSP
+    else if (device_type == base_device::AbacusDevice_t::DspDevice) {
+        cgemm_mth_(&transa, &transb, &m, &n, &k,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc, GlobalV::MY_RANK);
+    }
+    #endif
+}
+
+void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
+    const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
+    const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
+        zgemm_(&transa, &transb, &m, &n, &k,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc);
+    }
+    #ifdef __DSP
+    else if (device_type == base_device::AbacusDevice_t::DspDevice) {
+        zgemm_mth_(&transa, &transb, &m, &n, &k,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc, GlobalV::MY_RANK);
+    }
+    #endif
+}
+
+// Symm and Hemm part. Only col-major is supported, and only the CPU device is handled.
+
+void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
+    const float alpha, const float *a, const int lda, const float *b, const int ldb,
+    const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
+        ssymm_(&side, &uplo, &m, &n,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc);
+    } // any other device_type: no operation is performed
+}
+
+void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
+    const double alpha, const double *a, const int lda, const double *b, const int ldb,
+    const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
+        dsymm_(&side, &uplo, &m, &n,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc);
+    } // any other device_type: no operation is performed
+}
+
+void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
+    const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
+    const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
+        csymm_(&side, &uplo, &m, &n,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc);
+    } // any other device_type: no operation is performed
+}
+
+void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
+    const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
+    const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
+        zsymm_(&side, &uplo, &m, &n,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc);
+    } // any other device_type: no operation is performed
+}
+
+void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
+    std::complex<float> alpha, std::complex<float> *a, int lda, std::complex<float> *b, int ldb,
+    std::complex<float> beta, std::complex<float> *c, int ldc, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
+        chemm_(&side, &uplo, &m, &n,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc);
+    } // any other device_type: no operation is performed
+}
+
+void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
+    std::complex<double> alpha, std::complex<double> *a, int lda, std::complex<double> *b, int ldb,
+    std::complex<double> beta, std::complex<double> *c, int ldc, base_device::AbacusDevice_t device_type)
+{
+    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
+        zhemm_(&side, &uplo, &m, &n,
+            &alpha, a, &lda, b, &ldb,
+            &beta, c, &ldc);
+    } // any other device_type: no operation is performed
+}
+
 void BlasConnector::gemv(const char trans, const int m, const int n,
     const float alpha, const float* A, const int lda, const float* X, const int incx,
     const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type)
@@ -190,7 +332,6 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
     }
 }
 
-
 // out = ||x||_2
 float BlasConnector::nrm2( const int n, const float *X, const int incX, base_device::AbacusDevice_t device_type )
 {
diff --git a/source/module_base/blas_connector.h b/source/module_base/blas_connector.h
index 090d8512ee..7675429520 100644
--- a/source/module_base/blas_connector.h
+++ b/source/module_base/blas_connector.h
@@ -111,11 +111,23 @@ extern "C"
         const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
         const std::complex<double> *beta, std::complex<double> *c, const int *ldc);
 
-    //a is symmetric
+    // A is symmetric. side = 'L': C = a * A * B + b * C; side = 'R': C = a * B * A + b * C
+    void ssymm_(const char *side, const char *uplo, const int *m, const int *n,
+        const float *alpha, const float *a, const int *lda, const float *b, const int *ldb,
+        const float *beta, float *c, const int *ldc);
     void dsymm_(const char *side, const char *uplo, const int *m, const int *n,
         const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
         const double *beta, double *c, const int *ldc);
-    //a is hermitian
+    void csymm_(const char *side, const char *uplo, const int *m, const int *n,
+        const std::complex<float> *alpha, const std::complex<float> *a, const int *lda, const std::complex<float> *b, const int *ldb,
+        const std::complex<float> *beta, std::complex<float> *c, const int *ldc);
+    void zsymm_(const char *side, const char *uplo, const int *m, const int *n,
+        const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
+        const std::complex<double> *beta, std::complex<double> *c, const int *ldc);
+
+    // A is hermitian. side = 'L': C = a * A * B + b * C; side = 'R': C = a * B * A + b * C
+    void chemm_(char *side, char *uplo, int *m, int *n,std::complex<float> *alpha,
+        std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, std::complex<float> *beta, std::complex<float> *c, int *ldc);
     void zhemm_(char *side, char *uplo, int *m, int *n,std::complex<double> *alpha,
         std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, std::complex<double> *beta, std::complex<double> *c, int *ldc);
 
@@ -175,6 +187,7 @@ class BlasConnector
 
     // Peize Lin add 2017-10-27, fix bug trans 2019-01-17
     // C = a * A.? * B.? + b * C
+    // Row Major by default
     static
     void gemm(const char transa, const char transb, const int m, const int n, const int k,
         const float alpha, const float *a, const int lda, const float *b, const int ldb,
@@ -195,6 +208,61 @@ class BlasConnector
         const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
         const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
 
+    // Col-major variants of gemm, for callers whose matrices are stored column-major
+
+    static
+    void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
+        const float alpha, const float *a, const int lda, const float *b, const int ldb,
+        const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    static
+    void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
+        const double alpha, const double *a, const int lda, const double *b, const int ldb,
+        const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    static
+    void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
+        const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
+        const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    static
+    void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
+        const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
+        const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    // symm and hemm cannot be packed into a row-major kernel by exchanging parameters, so only col-major functions are provided.
+    static
+    void symm_cm(const char side, const char uplo, const int m, const int n,
+        const float alpha, const float *a, const int lda, const float *b, const int ldb,
+        const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    static
+    void symm_cm(const char side, const char uplo, const int m, const int n,
+        const double alpha, const double *a, const int lda, const double *b, const int ldb,
+        const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    static
+    void symm_cm(const char side, const char uplo, const int m, const int n,
+        const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
+        const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    static
+    void symm_cm(const char side, const char uplo, const int m, const int n,
+        const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
+        const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    static
+    void hemm_cm(char side, char uplo, int m, int n,
+        std::complex<float> alpha, std::complex<float> *a, int lda, std::complex<float> *b, int ldb,
+        std::complex<float> beta, std::complex<float> *c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    static
+    void hemm_cm(char side, char uplo, int m, int n,
+        std::complex<double> alpha, std::complex<double> *a, int lda, std::complex<double> *b, int ldb,
+        std::complex<double> beta, std::complex<double> *c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    // y = A*x + beta*y
+
     static
     void gemv(const char trans, const int m, const int n,
         const float alpha, const float* A, const int lda, const float* X, const int incx,
@@ -234,6 +302,8 @@ class BlasConnector
 
     static
     void copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);
+
+    // A is symmetric
 };
 
 // If GATHER_INFO is defined, the original function is replaced with a "i" suffix,
diff --git a/source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp b/source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp
index fabe3b0773..efa51ec4a6 100644
--- a/source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp
+++ b/source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp
@@ -1,6 +1,7 @@
 #include "gint_tools.h"
 #include "module_base/timer.h"
 #include "module_base/ylm.h"
+#include "module_base/blas_connector.h"
 
 namespace Gint_Tools{
 
@@ -60,8 +61,8 @@ void mult_psi_DMR(
 
                 const auto tmp_matrix_ptr = tmp_matrix->get_pointer();
                 const int idx1 = block_index[ia1];
-                dsymm_(&side, &uplo, &block_size[ia1], &ib_len, &alpha, tmp_matrix_ptr, &block_size[ia1],
-                    &psi[ib_start][idx1], &LD_pool, &beta, &psi_DMR[ib_start][idx1], &LD_pool);
+                BlasConnector::symm_cm(side, uplo, block_size[ia1], ib_len, alpha, tmp_matrix_ptr, block_size[ia1],
+                    &psi[ib_start][idx1], LD_pool, beta, &psi_DMR[ib_start][idx1], LD_pool);
             }
 
             //! get (j,beta,R2)