Skip to content
3 changes: 3 additions & 0 deletions source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ list(APPEND device_srcs
module_hamilt_pw/hamilt_pwdft/kernels/veff_op.cpp
module_hamilt_pw/hamilt_pwdft/kernels/ekinetic_op.cpp
module_hamilt_pw/hamilt_pwdft/kernels/meta_op.cpp
module_hamilt_pw/hamilt_stodft/kernels/hpsi_norm_op.cpp
module_basis/module_pw/kernels/pw_op.cpp
module_hsolver/kernels/dngvd_op.cpp
module_hsolver/kernels/math_kernel_op.cpp
Expand All @@ -55,6 +56,7 @@ if(USE_CUDA)
module_hamilt_pw/hamilt_pwdft/kernels/cuda/veff_op.cu
module_hamilt_pw/hamilt_pwdft/kernels/cuda/ekinetic_op.cu
module_hamilt_pw/hamilt_pwdft/kernels/cuda/meta_op.cu
module_hamilt_pw/hamilt_stodft/kernels/cuda/hpsi_norm_op.cu
module_basis/module_pw/kernels/cuda/pw_op.cu
module_hsolver/kernels/cuda/dngvd_op.cu
module_hsolver/kernels/cuda/math_kernel_op.cu
Expand All @@ -78,6 +80,7 @@ if(USE_ROCM)
module_hamilt_pw/hamilt_pwdft/kernels/rocm/veff_op.hip.cu
module_hamilt_pw/hamilt_pwdft/kernels/rocm/ekinetic_op.hip.cu
module_hamilt_pw/hamilt_pwdft/kernels/rocm/meta_op.hip.cu
module_hamilt_pw/hamilt_stodft/kernels/rocm/hpsi_norm_op.hip.cu
module_basis/module_pw/kernels/rocm/pw_op.hip.cu
module_hsolver/kernels/rocm/dngvd_op.hip.cu
module_hsolver/kernels/rocm/math_kernel_op.hip.cu
Expand Down
2 changes: 2 additions & 0 deletions source/Makefile.Objects
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ VPATH=./src_global:\
./module_hamilt_pw/hamilt_stodft:\
./module_hamilt_pw/hamilt_pwdft/operator_pw:\
./module_hamilt_pw/hamilt_pwdft/kernels:\
./module_hamilt_pw/hamilt_stodft/kernels:\
./module_hamilt_lcao/module_hcontainer:\
./module_hamilt_lcao/hamilt_lcaodft:\
./module_hamilt_lcao/module_tddft:\
Expand Down Expand Up @@ -295,6 +296,7 @@ OBJS_HAMILT=hamilt_pw.o\
operator_pw.o\
ekinetic_pw.o\
ekinetic_op.o\
hpsi_norm_op.o\
veff_pw.o\
veff_op.o\
nonlocal_pw.o\
Expand Down
2 changes: 2 additions & 0 deletions source/module_base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ add_library(
${LIBM_SRC}
)

target_link_libraries(base PUBLIC container)

add_subdirectory(module_container)

if(ENABLE_COVERAGE)
Expand Down
9 changes: 9 additions & 0 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
#endif
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const float alpha, const float* A, const int lda, const float* X, const int incx,
const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
sgemv_(&trans, &m, &n, &alpha, A, &lda, X, &incx, &beta, Y, &incy);
}
}

void BlasConnector::gemv(const char trans, const int m, const int n,
const double alpha, const double* A, const int lda, const double* X, const int incx,
const double beta, double* Y, const int incy, base_device::AbacusDevice_t device_type)
Expand Down
8 changes: 8 additions & 0 deletions source/module_base/blas_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ extern "C"
double dznrm2_( const int *n, const std::complex<double> *X, const int *incX );

// level 2: matrix-std::vector operations, O(n^2) data and O(n^2) work.
void sgemv_(const char*const transa, const int*const m, const int*const n,
const float*const alpha, const float*const a, const int*const lda, const float*const x, const int*const incx,
const float*const beta, float*const y, const int*const incy);
void dgemv_(const char*const transa, const int*const m, const int*const n,
const double*const alpha, const double*const a, const int*const lda, const double*const x, const int*const incx,
const double*const beta, double*const y, const int*const incy);
Expand Down Expand Up @@ -178,6 +181,11 @@ class BlasConnector
const std::complex<double> alpha, const std::complex<double> *a, const int lda, const std::complex<double> *b, const int ldb,
const std::complex<double> beta, std::complex<double> *c, const int ldc, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

static
void gemv(const char trans, const int m, const int n,
const float alpha, const float* A, const int lda, const float* X, const int incx,
const float beta, float* Y, const int incy, base_device::AbacusDevice_t device_type = base_device::AbacusDevice_t::CpuDevice);

static
void gemv(const char trans, const int m, const int n,
const double alpha, const double* A, const int lda, const double* X, const int incx,
Expand Down
Loading
Loading