Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 142 additions & 1 deletion source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ double BlasConnector::dot( const int n, const double *X, const int incX, const d
}

// C = a * A.? * B.? + b * C
// Row-Major part
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
Expand Down Expand Up @@ -154,6 +155,147 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
#endif
}

// Col-Major part
// Column-major single-precision GEMM: C = alpha * op(A) * op(B) + beta * C.
//
// All arguments are forwarded unchanged, in Fortran BLAS order, to the backend
// chosen by device_type:
//   - CpuDevice : sgemm_
//   - DspDevice : sgemm_mth_ (only compiled when __DSP is defined); it also
//                 receives GlobalV::MY_RANK as a trailing argument.
// Any other device_type is ignored and c is left untouched.
//
// transa / transb : transpose flags for A and B respectively
// m, n, k         : problem sizes, passed through to the backend as-is
// lda, ldb, ldc   : leading dimensions of a, b and c
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
    const float alpha, const float *a, const int lda, const float *b, const int ldb,
    const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
{
    if (device_type == base_device::AbacusDevice_t::CpuDevice) {
        sgemm_(&transa, &transb, &m, &n, &k,
               &alpha, a, &lda, b, &ldb,
               &beta, c, &ldc);
    }
#ifdef __DSP
    else if (device_type == base_device::AbacusDevice_t::DspDevice) {
        // The transpose flags must follow the same order as the matrices (a, then b).
        // Swapping them is only valid in the row-major wrapper, where m/n and a/b are
        // swapped as well; here it would apply transa to B and transb to A.
        sgemm_mth_(&transa, &transb, &m, &n, &k,
                   &alpha, a, &lda, b, &ldb,
                   &beta, c, &ldc, GlobalV::MY_RANK);
    }
#endif
}

// Column-major double-precision GEMM: C = alpha * op(A) * op(B) + beta * C.
// Arguments are handed to the selected backend without any reordering.
// Device types other than CpuDevice (and DspDevice when built with __DSP)
// fall through and do nothing.
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
    const double alpha, const double *a, const int lda, const double *b, const int ldb,
    const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
{
    // Host path: Fortran BLAS.
    if (device_type == base_device::AbacusDevice_t::CpuDevice)
    {
        dgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
        return;
    }
#ifdef __DSP
    // DSP path: same argument list plus the MPI rank of this process.
    if (device_type == base_device::AbacusDevice_t::DspDevice)
    {
        dgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK);
    }
#endif
}

// Column-major single-precision complex GEMM: C = alpha * op(A) * op(B) + beta * C.
// Arguments are handed to the selected backend without any reordering.
// Device types other than CpuDevice (and DspDevice when built with __DSP)
// fall through and do nothing.
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
    const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
    const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type)
{
    // Host path: Fortran BLAS.
    if (device_type == base_device::AbacusDevice_t::CpuDevice)
    {
        cgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
        return;
    }
#ifdef __DSP
    // DSP path: same argument list plus the MPI rank of this process.
    if (device_type == base_device::AbacusDevice_t::DspDevice)
    {
        cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK);
    }
#endif
}

// Column-major double-precision complex GEMM: C = alpha * op(A) * op(B) + beta * C.
// Arguments are handed to the selected backend without any reordering.
// Device types other than CpuDevice (and DspDevice when built with __DSP)
// fall through and do nothing.
void BlasConnector::gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
    const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
    const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type)
{
    // Host path: Fortran BLAS.
    if (device_type == base_device::AbacusDevice_t::CpuDevice)
    {
        zgemm_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
        return;
    }
#ifdef __DSP
    // DSP path: same argument list plus the MPI rank of this process.
    if (device_type == base_device::AbacusDevice_t::DspDevice)
    {
        zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK);
    }
#endif
}

// Symm and Hemm part. Only col-major is supported.

// Column-major single-precision symmetric matrix multiply, forwarded to ssymm_.
// Only the CPU backend is implemented; any other device_type returns
// immediately without touching c.
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
    const float alpha, const float *a, const int lda, const float *b, const int ldb,
    const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type)
{
    const bool run_on_cpu = (device_type == base_device::AbacusDevice_t::CpuDevice);
    if (!run_on_cpu)
    {
        return;
    }

    ssymm_(&side, &uplo, &m, &n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
}

// Column-major double-precision symmetric matrix multiply, forwarded to dsymm_.
// Only the CPU backend is implemented; any other device_type returns
// immediately without touching c.
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
    const double alpha, const double *a, const int lda, const double *b, const int ldb,
    const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type)
{
    const bool run_on_cpu = (device_type == base_device::AbacusDevice_t::CpuDevice);
    if (!run_on_cpu)
    {
        return;
    }

    dsymm_(&side, &uplo, &m, &n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
}

// Column-major single-precision complex symmetric matrix multiply, forwarded to csymm_.
// Only the CPU backend is implemented; any other device_type returns
// immediately without touching c.
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
    const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
    const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type)
{
    const bool run_on_cpu = (device_type == base_device::AbacusDevice_t::CpuDevice);
    if (!run_on_cpu)
    {
        return;
    }

    csymm_(&side, &uplo, &m, &n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
}

// Column-major double-precision complex symmetric matrix multiply, forwarded to zsymm_.
// Only the CPU backend is implemented; any other device_type returns
// immediately without touching c.
void BlasConnector::symm_cm(const char side, const char uplo, const int m, const int n,
    const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
    const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type)
{
    const bool run_on_cpu = (device_type == base_device::AbacusDevice_t::CpuDevice);
    if (!run_on_cpu)
    {
        return;
    }

    zsymm_(&side, &uplo, &m, &n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
}

// Column-major single-precision Hermitian matrix multiply, forwarded to chemm_.
// The parameters are non-const because chemm_ is declared with non-const
// pointers; the by-value arguments are local copies whose addresses are passed.
// Only the CPU backend is implemented; any other device_type returns
// immediately without touching c.
void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
    std::complex<float> alpha, std::complex<float> *a, int lda, std::complex<float> *b, int ldb,
    std::complex<float> beta, std::complex<float> *c, int ldc, base_device::AbacusDevice_t device_type)
{
    const bool run_on_cpu = (device_type == base_device::AbacusDevice_t::CpuDevice);
    if (!run_on_cpu)
    {
        return;
    }

    chemm_(&side, &uplo, &m, &n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
}

// Column-major double-precision Hermitian matrix multiply, forwarded to zhemm_.
// The parameters are non-const because zhemm_ is declared with non-const
// pointers; the by-value arguments are local copies whose addresses are passed.
// Only the CPU backend is implemented; any other device_type returns
// immediately without touching c.
void BlasConnector::hemm_cm(char side, char uplo, int m, int n,
    std::complex<double> alpha, std::complex<double> *a, int lda, std::complex<double> *b, int ldb,
    std::complex<double> beta, std::complex<double> *c, int ldc, base_device::AbacusDevice_t device_type)
{
    const bool run_on_cpu = (device_type == base_device::AbacusDevice_t::CpuDevice);
    if (!run_on_cpu)
    {
        return;
    }

    zhemm_(&side, &uplo, &m, &n, &alpha, a, &lda, b, &ldb, &beta, c, &ldc);
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const float alpha, const float* A, const int lda, const float* X, const int incx,
const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type)
Expand Down Expand Up @@ -190,7 +332,6 @@ void BlasConnector::gemv(const char trans, const int m, const int n,
}
}


// out = ||x||_2
float BlasConnector::nrm2( const int n, const float *X, const int incX, base_device::AbacusDevice_t device_type )
{
Expand Down
74 changes: 72 additions & 2 deletions source/module_base/blas_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,23 @@ extern "C"
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
const std::complex<double> *beta, std::complex<double> *c, const int *ldc);

// SYMM family (Fortran BLAS, column-major storage): A is symmetric.
// Per the reference BLAS specification:
//   side = 'L' : C = alpha * A * B + beta * C
//   side = 'R' : C = alpha * B * A + beta * C
// uplo selects which triangle of A ('U' or 'L') is referenced.
// Real variants:
void ssymm_(const char *side, const char *uplo, const int *m, const int *n,
const float *alpha, const float *a, const int *lda, const float *b, const int *ldb,
const float *beta, float *c, const int *ldc);
void dsymm_(const char *side, const char *uplo, const int *m, const int *n,
const double *alpha, const double *a, const int *lda, const double *b, const int *ldb,
const double *beta, double *c, const int *ldc);
// Complex variants: A is complex symmetric, not Hermitian (chemm_/zhemm_ cover the Hermitian case).
void csymm_(const char *side, const char *uplo, const int *m, const int *n,
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda, const std::complex<float> *b, const int *ldb,
const std::complex<float> *beta, std::complex<float> *c, const int *ldc);
void zsymm_(const char *side, const char *uplo, const int *m, const int *n,
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda, const std::complex<double> *b, const int *ldb,
const std::complex<double> *beta, std::complex<double> *c, const int *ldc);

// HEMM family (Fortran BLAS, column-major storage): A is Hermitian.
// Per the reference BLAS specification:
//   side = 'L' : C = alpha * A * B + beta * C
//   side = 'R' : C = alpha * B * A + beta * C
// uplo selects which triangle of A ('U' or 'L') is referenced.
// NOTE(review): unlike the ?symm_ declarations, these take non-const pointers,
// so callers must pass mutable arguments — consider const-qualifying the inputs.
void chemm_(char *side, char *uplo, int *m, int *n,std::complex<float> *alpha,
std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, std::complex<float> *beta, std::complex<float> *c, int *ldc);
void zhemm_(char *side, char *uplo, int *m, int *n,std::complex<double> *alpha,
std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, std::complex<double> *beta, std::complex<double> *c, int *ldc);

Expand Down Expand Up @@ -175,6 +187,7 @@ class BlasConnector

// Peize Lin add 2017-10-27, fix bug trans 2019-01-17
// C = a * A.? * B.? + b * C
// Row Major by default
static
void gemm(const char transa, const char transb, const int m, const int n, const int k,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
Expand All @@ -195,6 +208,61 @@ class BlasConnector
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// Column-major GEMM: C = alpha * op(A) * op(B) + beta * C.
// Use these overloads when the matrices are stored column-major (Fortran
// layout); the arguments are passed to the BLAS ?gemm_ routines as given.
// device_type defaults to the CPU backend.

// float overload
static
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// double overload
static
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const double alpha, const double *a, const int lda, const double *b, const int ldb,
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// std::complex<float> overload
static
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// std::complex<double> overload
static
void gemm_cm(const char transa, const char transb, const int m, const int n, const int k,
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// Symmetric matrix multiply, column-major only.
// symm and hemm cannot be mapped onto a row-major kernel just by exchanging
// parameters, so no row-major variants are provided.
// device_type defaults to the CPU backend.

// float overload
static
void symm_cm(const char side, const char uplo, const int m, const int n,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
const float beta, float *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// double overload
static
void symm_cm(const char side, const char uplo, const int m, const int n,
const double alpha, const double *a, const int lda, const double *b, const int ldb,
const double beta, double *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// std::complex<float> overload
static
void symm_cm(const char side, const char uplo, const int m, const int n,
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
const std::complex<float> beta, std::complex<float> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// std::complex<double> overload
static
void symm_cm(const char side, const char uplo, const int m, const int n,
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// Hermitian matrix multiply, column-major only.
// Parameters are non-const, matching the non-const chemm_/zhemm_ declarations.
// device_type defaults to the CPU backend.

// std::complex<float> overload
static
void hemm_cm(char side, char uplo, int m, int n,
std::complex<float> alpha, std::complex<float> *a, int lda, std::complex<float> *b, int ldb,
std::complex<float> beta, std::complex<float> *c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// std::complex<double> overload
static
void hemm_cm(char side, char uplo, int m, int n,
std::complex<double> alpha, std::complex<double> *a, int lda, std::complex<double> *b, int ldb,
std::complex<double> beta, std::complex<double> *c, int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// y = alpha * op(A) * x + beta * y, where op(A) is selected by trans

static
void gemv(const char trans, const int m, const int n,
const float alpha, const float* A, const int lda, const float* X, const int incx,
Expand Down Expand Up @@ -234,6 +302,8 @@ class BlasConnector

// std::complex<double> overload of copy; device_type defaults to the CPU backend.
// NOTE(review): presumably copies n elements from a (stride incx) into b (stride incy),
// like BLAS zcopy — the definition is not visible here, confirm against it.
// NOTE(review): n is a long while the strides are int — check how the definition narrows it.
static
void copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

// A is symmetric
};

// If GATHER_INFO is defined, the original function is replaced with a "i" suffix,
Expand Down
5 changes: 3 additions & 2 deletions source/module_hamilt_lcao/module_gint/mult_psi_dmr.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "gint_tools.h"
#include "module_base/timer.h"
#include "module_base/ylm.h"
#include "module_base/blas_connector.h"

namespace Gint_Tools{

Expand Down Expand Up @@ -60,8 +61,8 @@ void mult_psi_DMR(

const auto tmp_matrix_ptr = tmp_matrix->get_pointer();
const int idx1 = block_index[ia1];
dsymm_(&side, &uplo, &block_size[ia1], &ib_len, &alpha, tmp_matrix_ptr, &block_size[ia1],
&psi[ib_start][idx1], &LD_pool, &beta, &psi_DMR[ib_start][idx1], &LD_pool);
BlasConnector::symm_cm(side, uplo, block_size[ia1], ib_len, alpha, tmp_matrix_ptr, block_size[ia1],
&psi[ib_start][idx1], LD_pool, beta, &psi_DMR[ib_start][idx1], LD_pool);
}

//! get (j,beta,R2)
Expand Down
Loading