Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ OBJS_MAIN=main.o\

OBJS_BASE=abfs-vector3_order.o\
assoc_laguerre.o\
blas_connector.o\
complexarray.o\
complexmatrix.o\
clebsch_gordan_coeff.o\
Expand Down
1 change: 1 addition & 0 deletions source/module_base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_library(
base
OBJECT
assoc_laguerre.cpp
blas_connector.cpp
clebsch_gordan_coeff.cpp
complexarray.cpp
complexmatrix.cpp
Expand Down
143 changes: 143 additions & 0 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
#include "blas_connector.h"

void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY)
{
saxpy_(&n, &alpha, X, &incX, Y, &incY);
}

void BlasConnector::axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY)
{
daxpy_(&n, &alpha, X, &incX, Y, &incY);
}

void BlasConnector::axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY)
{
caxpy_(&n, &alpha, X, &incX, Y, &incY);
}

void BlasConnector::axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY)
{
zaxpy_(&n, &alpha, X, &incX, Y, &incY);
}


// x=a*x
void BlasConnector::scal( const int n, const float alpha, float *X, const int incX)
{
sscal_(&n, &alpha, X, &incX);
}

void BlasConnector::scal( const int n, const double alpha, double *X, const int incX)
{
dscal_(&n, &alpha, X, &incX);
}

void BlasConnector::scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX)
{
cscal_(&n, &alpha, X, &incX);
}

void BlasConnector::scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX)
{
zscal_(&n, &alpha, X, &incX);
}


// d=x*y
float BlasConnector::dot( const int n, const float *X, const int incX, const float *Y, const int incY)
{
return sdot_(&n, X, &incX, Y, &incY);
}

double BlasConnector::dot( const int n, const double *X, const int incX, const double *Y, const int incY)
{
return ddot_(&n, X, &incX, Y, &incY);
}

// C = a * A.? * B.? + b * C
void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
const float beta, float *c, const int ldc)
{
sgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
const double alpha, const double *a, const int lda, const double *b, const int ldb,
const double beta, double *c, const int ldc)
{
dgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
const std::complex<float> beta, std::complex<float> *c, const int ldc)
{
cgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc)
{
zgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const double alpha, const double* A, const int lda, const double* X, const int incx,
const double beta, double* Y, const int incy)
{
dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const std::complex<float> alpha, const std::complex<float> *A, const int lda, const std::complex<float> *X, const int incx,
const std::complex<float> beta, std::complex<float> *Y, const int incy)
{
cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const std::complex<double> alpha, const std::complex<double> *A, const int lda, const std::complex<double> *X, const int incx,
const std::complex<double> beta, std::complex<double> *Y, const int incy)
{
zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}


// out = ||x||_2
float BlasConnector::nrm2( const int n, const float *X, const int incX )
{
return snrm2_( &n, X, &incX );
}


double BlasConnector::nrm2( const int n, const double *X, const int incX )
{
return dnrm2_( &n, X, &incX );
}


double BlasConnector::nrm2( const int n, const std::complex<double> *X, const int incX )
{
return dznrm2_( &n, X, &incX );
}

// copies a into b
void BlasConnector::copy(const long n, const double *a, const int incx, double *b, const int incy)
{
dcopy_(&n, a, &incx, b, &incy);
}

void BlasConnector::copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy)
{
zcopy_(&n, a, &incx, b, &incy);
}
192 changes: 71 additions & 121 deletions source/module_base/blas_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

#include <complex>

// These still need to be linked in the header file
// Because quite a lot of code will directly use the original cblas kernels.

extern "C"
{
// level 1: std::vector-std::vector operations, O(n) data and O(n) work.
Expand Down Expand Up @@ -115,152 +118,99 @@ class BlasConnector

// Peize Lin add 2016-08-04
// y=a*x+y
static inline
void axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY)
{
saxpy_(&n, &alpha, X, &incX, Y, &incY);
}
static inline
void axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY)
{
daxpy_(&n, &alpha, X, &incX, Y, &incY);
}
static inline
void axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY)
{
caxpy_(&n, &alpha, X, &incX, Y, &incY);
}
static inline
void axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY)
{
zaxpy_(&n, &alpha, X, &incX, Y, &incY);
}
static
void axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY);

static
void axpy( const int n, const double alpha, const double *X, const int incX, double *Y, const int incY);

static
void axpy( const int n, const std::complex<float> alpha, const std::complex<float> *X, const int incX, std::complex<float> *Y, const int incY);

static
void axpy( const int n, const std::complex<double> alpha, const std::complex<double> *X, const int incX, std::complex<double> *Y, const int incY);


// Peize Lin add 2016-08-04
// x=a*x
static inline
void scal( const int n, const float alpha, float *X, const int incX)
{
sscal_(&n, &alpha, X, &incX);
}
static inline
void scal( const int n, const double alpha, double *X, const int incX)
{
dscal_(&n, &alpha, X, &incX);
}
static inline
void scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX)
{
cscal_(&n, &alpha, X, &incX);
}
static inline
void scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX)
{
zscal_(&n, &alpha, X, &incX);
}
static
void scal( const int n, const float alpha, float *X, const int incX);

static
void scal( const int n, const double alpha, double *X, const int incX);

static
void scal( const int n, const std::complex<float> alpha, std::complex<float> *X, const int incX);

static
void scal( const int n, const std::complex<double> alpha, std::complex<double> *X, const int incX);


// Peize Lin add 2017-10-27
// d=x*y
static inline
float dot( const int n, const float *X, const int incX, const float *Y, const int incY)
{
return sdot_(&n, X, &incX, Y, &incY);
}
static inline
double dot( const int n, const double *X, const int incX, const double *Y, const int incY)
{
return ddot_(&n, X, &incX, Y, &incY);
}
static
float dot( const int n, const float *X, const int incX, const float *Y, const int incY);

static
double dot( const int n, const double *X, const int incX, const double *Y, const int incY);


// Peize Lin add 2017-10-27, fix bug trans 2019-01-17
// C = a * A.? * B.? + b * C
static inline
static
void gemm(const char transa, const char transb, const int m, const int n, const int k,
const float alpha, const float *a, const int lda, const float *b, const int ldb,
const float beta, float *c, const int ldc)
{
sgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
static inline
const float beta, float *c, const int ldc);

static
void gemm(const char transa, const char transb, const int m, const int n, const int k,
const double alpha, const double *a, const int lda, const double *b, const int ldb,
const double beta, double *c, const int ldc)
{
dgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
static inline
const double beta, double *c, const int ldc);

static
void gemm(const char transa, const char transb, const int m, const int n, const int k,
const std::complex<float> alpha, const std::complex<float> *a, const int lda, const std::complex<float> *b, const int ldb,
const std::complex<float> beta, std::complex<float> *c, const int ldc)
{
cgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
static inline
const std::complex<float> beta, std::complex<float> *c, const int ldc);

static
void gemm(const char transa, const char transb, const int m, const int n, const int k,
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc)
{
zgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
static inline
const std::complex<double> beta, std::complex<double> *c, const int ldc);

static
void gemv(const char trans, const int m, const int n,
const double alpha, const double* A, const int lda, const double* X, const int incx,
const double beta, double* Y, const int incy)
{
dgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
static inline
void gemv(const char trans, const int m, const int n,
const std::complex<float> alpha, const std::complex<float> *A, const int lda, const std::complex<float> *X, const int incx,
const std::complex<float> beta, std::complex<float> *Y, const int incy)
{
cgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
static inline
const double beta, double* Y, const int incy);

static
void gemv(const char trans, const int m, const int n,
const std::complex<float> alpha, const std::complex<float> *A, const int lda, const std::complex<float> *X, const int incx,
const std::complex<float> beta, std::complex<float> *Y, const int incy);

static
void gemv(const char trans, const int m, const int n,
const std::complex<double> alpha, const std::complex<double> *A, const int lda, const std::complex<double> *X, const int incx,
const std::complex<double> beta, std::complex<double> *Y, const int incy)
{
zgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
const std::complex<double> beta, std::complex<double> *Y, const int incy);


// Peize Lin add 2018-06-12
// out = ||x||_2
static inline
float nrm2( const int n, const float *X, const int incX )
{
return snrm2_( &n, X, &incX );
}
static inline
double nrm2( const int n, const double *X, const int incX )
{
return dnrm2_( &n, X, &incX );
}
static inline
double nrm2( const int n, const std::complex<double> *X, const int incX )
{
return dznrm2_( &n, X, &incX );
}
static
float nrm2( const int n, const float *X, const int incX );

static
double nrm2( const int n, const double *X, const int incX );

static
double nrm2( const int n, const std::complex<double> *X, const int incX );


// copies a into b
static inline
void copy(const long n, const double *a, const int incx, double *b, const int incy)
{
dcopy_(&n, a, &incx, b, &incy);
}
static inline
void copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy)
{
zcopy_(&n, a, &incx, b, &incy);
}
static
void copy(const long n, const double *a, const int incx, double *b, const int incy);

static
void copy(const long n, const std::complex<double> *a, const int incx, std::complex<double> *b, const int incy);
};

// If GATHER_INFO is defined, the original function is replaced with a "i" suffix,
Expand Down Expand Up @@ -308,4 +258,4 @@ void zgemv_i(const char *trans,
*/

#endif // GATHER_INFO
#endif // BLAS_CONNECTOR_H
#endif // BLAS_CONNECTOR_H
Loading
Loading