Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ option(ENABLE_RAPIDJSON "Enable rapid-json usage." OFF)
option(ENABLE_CNPY "Enable cnpy usage." OFF)
option(ENABLE_PEXSI "Enable support for PEXSI." OFF)
option(ENABLE_CUSOLVERMP "Enable cusolvermp." OFF)
option(USE_DSP "Enable DSP usage." OFF)

# enable json support
if(ENABLE_RAPIDJSON)
Expand Down Expand Up @@ -119,6 +120,12 @@ elseif(ENABLE_LCAO AND NOT ENABLE_MPI)
set(ABACUS_BIN_NAME abacus_serial)
endif()

if (USE_DSP)
set(USE_ELPA OFF)
set(ENABLE_LCAO OFF)
set(ABACUS_BIN_NAME abacus_dsp)
endif()

list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)

if(ENABLE_COVERAGE)
Expand Down Expand Up @@ -240,6 +247,11 @@ if(ENABLE_MPI)
list(APPEND math_libs MPI::MPI_CXX)
endif()

if (USE_DSP)
target_link_libraries(${ABACUS_BIN_NAME} ${DIR_MTBLAS_LIBRARY})
add_compile_definitions(__DSP)
endif()

find_package(Threads REQUIRED)
target_link_libraries(${ABACUS_BIN_NAME} Threads::Threads)

Expand Down
10 changes: 10 additions & 0 deletions install_dsp.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
CXX=mpicxx \
cmake -B build \
-DUSE_DSP=ON \
-DENABLE_LCAO=OFF \
-DFFTW3_DIR=/vol8/appsoftware/fftw/ \
-DFFTW3_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3.so \
-DFFTW3_OMP_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3_omp.so \
-DFFTW3_FLOAT_LIBRARY=/vol8/appsoftware/fftw/lib/libfftw3f.so \
-DLAPACK_DIR=/vol8/appsoftware/openblas/0.3.21/lib \
-DDIR_MTBLAS_LIBRARY=/vol8/home/dptech_zyz1/develop/packages/libmtblas_abacus.so
45 changes: 41 additions & 4 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#include "blas_connector.h"

#ifdef __DSP
#include "module_base/kernels/dsp/dsp_connector.h"
#endif

void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
Expand Down Expand Up @@ -64,13 +68,15 @@ float BlasConnector::dot( const int n, const float *X, const int incX, const flo
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return sdot_(&n, X, &incX, Y, &incY);
return sdot_(&n, X, &incX, Y, &incY);
}
}

double BlasConnector::dot( const int n, const double *X, const int incX, const double *Y, const int incY, base_device::AbacusDevice_t device_type)
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return ddot_(&n, X, &incX, Y, &incY);
return ddot_(&n, X, &incX, Y, &incY);
}
}

Expand All @@ -83,7 +89,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
sgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#endif
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -94,7 +107,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
dgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#endif
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -105,7 +125,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
cgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#endif
}

void BlasConnector::gemm(const char transa, const char transb, const int m, const int n, const int k,
Expand All @@ -116,7 +143,14 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
zgemm_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
}
#ifdef __DSP
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
}
#endif
}

void BlasConnector::gemv(const char trans, const int m, const int n,
Expand Down Expand Up @@ -152,6 +186,7 @@ float BlasConnector::nrm2( const int n, const float *X, const int incX, base_dev
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return snrm2_( &n, X, &incX );
return snrm2_( &n, X, &incX );
}
}

Expand All @@ -160,6 +195,7 @@ double BlasConnector::nrm2( const int n, const double *X, const int incX, base_d
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return dnrm2_( &n, X, &incX );
return dnrm2_( &n, X, &incX );
}
}

Expand All @@ -168,6 +204,7 @@ double BlasConnector::nrm2( const int n, const std::complex<double> *X, const in
{
if (device_type == base_device::AbacusDevice_t::CpuDevice) {
return dznrm2_( &n, X, &incX );
return dznrm2_( &n, X, &incX );
}
}

Expand Down
66 changes: 66 additions & 0 deletions source/module_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#ifndef DSP_CONNECTOR_H
#define DSP_CONNECTOR_H
#ifdef __DSP

// Base dsp functions
void dspInitHandle(int id);
void dspDestoryHandle();
void *malloc_ht(size_t bytes);
void free_ht(void* ptr);


// mtblas functions

void sgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const float *alpha, const float *a, const int *lda,
const float *b, const int *ldb, const float *beta,
float *c, const int *ldc);

void dgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const double *alpha,const double *a, const int *lda,
const double *b, const int *ldb, const double *beta,
double *c, const int *ldc);

void zgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
std::complex<double> *c, const int *ldc);

void cgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
std::complex<float> *c, const int *ldc);


void sgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const float *alpha, const float *a, const int *lda,
const float *b, const int *ldb, const float *beta,
float *c, const int *ldc);

void dgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const double *alpha,const double *a, const int *lda,
const double *b, const int *ldb, const double *beta,
double *c, const int *ldc);

void zgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
std::complex<double> *c, const int *ldc);

void cgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
std::complex<float> *c, const int *ldc);

//#define zgemm_ zgemm_mt

#endif
#endif
15 changes: 15 additions & 0 deletions source/module_base/module_device/memory_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

#include "module_base/memory.h"
#include "module_base/tool_threading.h"
#ifdef __DSP
#include "module_base/kernels/dsp/dsp_connector.h"
#endif

#include <complex>
#include <cstring>
Expand All @@ -18,9 +21,17 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_CPU>
{
if (arr != nullptr)
{
#ifdef __DSP
free_ht(arr);
#else
free(arr);
#endif
}
#ifdef __DSP
arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size);
#else
arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size);
#endif
std::string record_string;
if (record_in != nullptr)
{
Expand Down Expand Up @@ -92,7 +103,11 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
{
#ifdef __DSP
free_ht(arr);
#else
free(arr);
#endif
}
};

Expand Down
1 change: 1 addition & 0 deletions source/module_base/module_device/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ enum AbacusDevice_t
UnKnown,
CpuDevice,
GpuDevice,
DspDevice
};

} // namespace base_device
Expand Down
13 changes: 12 additions & 1 deletion source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
#include <ATen/kernels/blas.h>
#include <ATen/kernels/lapack.h>

#ifdef __DSP
#include "module_base/kernels/dsp/dsp_connector.h"
#endif

namespace ModuleESolver
{

Expand All @@ -67,6 +71,10 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
container::kernels::createGpuSolverHandle();
}
#endif
#ifdef __DSP
std::cout << " ** Initializing DSP Hardware..." << std::endl;
dspInitHandle(GlobalV::MY_RANK % 4);
#endif
}

template <typename T, typename Device>
Expand All @@ -92,7 +100,10 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
#endif
delete reinterpret_cast<psi::Psi<T, Device>*>(this->kspw_psi);
}

#ifdef __DSP
std::cout << " ** Closing DSP Hardware..." << std::endl;
dspDestoryHandle();
#endif
if (PARAM.inp.precision == "single")
{
delete reinterpret_cast<psi::Psi<std::complex<double>, Device>*>(this->__kspw_psi);
Expand Down
42 changes: 36 additions & 6 deletions source/module_hsolver/diago_dav_subspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,12 @@ int Diago_DavSubspace<T, Device>::diag_once(const HPsiFunc& hpsi_func,
// updata eigenvectors of Hamiltonian
setmem_complex_op()(this->ctx, psi_in, 0, n_band * psi_in_dmax);

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'N',
'N',
this->dim,
Expand Down Expand Up @@ -262,7 +267,12 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
}
}

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'N',
'N',
this->dim,
Expand Down Expand Up @@ -302,7 +312,12 @@ void Diago_DavSubspace<T, Device>::cal_grad(const HPsiFunc& hpsi_func,
delmem_real_op()(this->ctx, e_temp_hd);
}

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'N',
'N',
this->dim,
Expand Down Expand Up @@ -386,7 +401,12 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
{
ModuleBase::timer::tick("Diago_DavSubspace", "cal_elem");

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'C',
'N',
nbase + notconv,
Expand All @@ -401,7 +421,12 @@ void Diago_DavSubspace<T, Device>::cal_elem(const int& dim,
&hcc[nbase * this->nbase_x],
this->nbase_x);

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'C',
'N',
nbase + notconv,
Expand Down Expand Up @@ -603,7 +628,12 @@ void Diago_DavSubspace<T, Device>::refresh(const int& dim,
{
ModuleBase::timer::tick("Diago_DavSubspace", "refresh");

gemm_op<T, Device>()(this->ctx,
#ifdef __DSP
gemm_op_mt<T, Device>()
#else
gemm_op<T, Device>()
#endif
(this->ctx,
'N',
'N',
this->dim,
Expand Down
Loading
Loading