Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#ifdef __DSP
#include "module_base/kernels/dsp/dsp_connector.h"
#include "module_base/global_variable.h"
#endif

void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type)
Expand Down Expand Up @@ -94,7 +95,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
else if (device_type == base_device::AbacusDevice_t::DspDevice){
sgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
}
Expand All @@ -112,7 +113,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
else if (device_type == base_device::AbacusDevice_t::DspDevice){
dgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
}
Expand All @@ -130,7 +131,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
cgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
}
Expand All @@ -148,7 +149,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons
else if (device_type == base_device::AbacusDevice_t::DspDevice) {
zgemm_mt_(&transb, &transa, &n, &m, &k,
&alpha, b, &ldb, a, &lda,
&beta, c, &ldc);
&beta, c, &ldc, GlobalV::MY_RANK);
}
#endif
}
Expand Down
20 changes: 10 additions & 10 deletions source/module_base/kernels/dsp/dsp_connector.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

// Base dsp functions
void dspInitHandle(int id);
void dspDestoryHandle();
void *malloc_ht(size_t bytes);
void dspDestoryHandle(int id);
void *malloc_ht(size_t bytes, int cluster_id);
void free_ht(void* ptr);


Expand All @@ -15,50 +15,50 @@ void sgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const float *alpha, const float *a, const int *lda,
const float *b, const int *ldb, const float *beta,
float *c, const int *ldc);
float *c, const int *ldc, int cluster_id);

void dgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const double *alpha,const double *a, const int *lda,
const double *b, const int *ldb, const double *beta,
double *c, const int *ldc);
double *c, const int *ldc, int cluster_id);

void zgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
std::complex<double> *c, const int *ldc);
std::complex<double> *c, const int *ldc, int cluster_id);

void cgemm_mt_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
std::complex<float> *c, const int *ldc);
std::complex<float> *c, const int *ldc, int cluster_id);


void sgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const float *alpha, const float *a, const int *lda,
const float *b, const int *ldb, const float *beta,
float *c, const int *ldc);
float *c, const int *ldc, int cluster_id);

void dgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const double *alpha,const double *a, const int *lda,
const double *b, const int *ldb, const double *beta,
double *c, const int *ldc);
double *c, const int *ldc, int cluster_id);

void zgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<double> *alpha, const std::complex<double> *a, const int *lda,
const std::complex<double> *b, const int *ldb, const std::complex<double> *beta,
std::complex<double> *c, const int *ldc);
std::complex<double> *c, const int *ldc, int cluster_id);

void cgemm_mth_(const char *transa, const char *transb,
const int *m, const int *n, const int *k,
const std::complex<float> *alpha, const std::complex<float> *a, const int *lda,
const std::complex<float> *b, const int *ldb, const std::complex<float> *beta,
std::complex<float> *c, const int *ldc);
std::complex<float> *c, const int *ldc, int cluster_id);

//#define zgemm_ zgemm_mt

Expand Down
13 changes: 1 addition & 12 deletions source/module_base/module_device/memory_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "module_base/tool_threading.h"
#ifdef __DSP
#include "module_base/kernels/dsp/dsp_connector.h"
#include "module_base/global_variable.h"
#endif

#include <complex>
Expand All @@ -21,17 +22,9 @@ struct resize_memory_op<FPTYPE, base_device::DEVICE_CPU>
{
if (arr != nullptr)
{
#ifdef __DSP
free_ht(arr);
#else
free(arr);
#endif
}
#ifdef __DSP
arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size);
#else
arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size);
#endif
std::string record_string;
if (record_in != nullptr)
{
Expand Down Expand Up @@ -103,11 +96,7 @@ struct delete_memory_op<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr)
{
#ifdef __DSP
free_ht(arr);
#else
free(arr);
#endif
}
};

Expand Down
4 changes: 2 additions & 2 deletions source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ ESolver_KS_PW<T, Device>::ESolver_KS_PW()
#endif
#ifdef __DSP
std::cout << " ** Initializing DSP Hardware..." << std::endl;
dspInitHandle(GlobalV::MY_RANK % 4);
dspInitHandle(GlobalV::MY_RANK);
#endif
}

Expand Down Expand Up @@ -102,7 +102,7 @@ ESolver_KS_PW<T, Device>::~ESolver_KS_PW()
}
#ifdef __DSP
std::cout << " ** Closing DSP Hardware..." << std::endl;
dspDestoryHandle();
dspDestoryHandle(GlobalV::MY_RANK);
#endif
if (PARAM.inp.precision == "single")
{
Expand Down
Loading