From d82c508ed84411fc555f6b1b0068163bf983b4f0 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Thu, 24 Oct 2024 23:28:31 +0800 Subject: [PATCH 1/3] Fix parallel function --- source/module_base/blas_connector.cpp | 9 +++++---- source/module_base/kernels/dsp/dsp_connector.h | 16 ++++++++-------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/source/module_base/blas_connector.cpp b/source/module_base/blas_connector.cpp index 075e4df297..1de321ca99 100644 --- a/source/module_base/blas_connector.cpp +++ b/source/module_base/blas_connector.cpp @@ -2,6 +2,7 @@ #ifdef __DSP #include "module_base/kernels/dsp/dsp_connector.h" +#include "module_base/global_variable.h" #endif void BlasConnector::axpy( const int n, const float alpha, const float *X, const int incX, float *Y, const int incY, base_device::AbacusDevice_t device_type) @@ -94,7 +95,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons else if (device_type == base_device::AbacusDevice_t::DspDevice){ sgemm_mt_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); + &beta, c, &ldc, GlobalV::MY_RANK); } #endif } @@ -112,7 +113,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons else if (device_type == base_device::AbacusDevice_t::DspDevice){ dgemm_mt_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); + &beta, c, &ldc, GlobalV::MY_RANK); } #endif } @@ -130,7 +131,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons else if (device_type == base_device::AbacusDevice_t::DspDevice) { cgemm_mt_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); + &beta, c, &ldc, GlobalV::MY_RANK); } #endif } @@ -148,7 +149,7 @@ void BlasConnector::gemm(const char transa, const char transb, const int m, cons else if (device_type == base_device::AbacusDevice_t::DspDevice) { zgemm_mt_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, - &beta, c, &ldc); + &beta, c, &ldc, GlobalV::MY_RANK); } #endif } diff --git a/source/module_base/kernels/dsp/dsp_connector.h b/source/module_base/kernels/dsp/dsp_connector.h index 2d3075fcd1..0aaacca2d8 100644 --- a/source/module_base/kernels/dsp/dsp_connector.h +++ b/source/module_base/kernels/dsp/dsp_connector.h @@ -15,50 +15,50 @@ void sgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, - float *c, const int *ldc); + float *c, const int *ldc, int cluster_id); void dgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha,const double *a, const int *lda, const double *b, const int *ldb, const double *beta, - double *c, const int *ldc); + double *c, const int *ldc, int cluster_id); void zgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc); + std::complex *c, const int *ldc, int cluster_id); void cgemm_mt_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc); + std::complex *c, const int *ldc, int cluster_id); void sgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *a, const int *lda, const float *b, const int *ldb, const float *beta, - float *c, const int *ldc); + float *c, const int *ldc, int cluster_id); void dgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const double *alpha,const double *a, const int *lda, const double *b, const int *ldb, const double *beta, - double *c, const int *ldc); + double *c, const int *ldc, int cluster_id); void zgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc); + std::complex *c, const int *ldc, int cluster_id); void cgemm_mth_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const std::complex *alpha, const std::complex *a, const int *lda, const std::complex *b, const int *ldb, const std::complex *beta, - std::complex *c, const int *ldc); + std::complex *c, const int *ldc, int cluster_id); //#define zgemm_ zgemm_mt From 262ce2a37dd5d7ec723bbfc94792f7634e91bad0 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Fri, 25 Oct 2024 14:48:45 +0800 Subject: [PATCH 2/3] Fix parallel usage --- source/module_base/kernels/dsp/dsp_connector.h | 4 ++-- source/module_base/module_device/memory_op.cpp | 3 ++- source/module_esolver/esolver_ks_pw.cpp | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/source/module_base/kernels/dsp/dsp_connector.h b/source/module_base/kernels/dsp/dsp_connector.h index 0aaacca2d8..a928a0b095 100644 --- a/source/module_base/kernels/dsp/dsp_connector.h +++ b/source/module_base/kernels/dsp/dsp_connector.h @@ -4,8 +4,8 @@ // Base dsp functions void dspInitHandle(int id); -void dspDestoryHandle(); -void *malloc_ht(size_t bytes); +void dspDestoryHandle(int id); +void *malloc_ht(size_t bytes, int cluster_id); void free_ht(void* ptr); diff --git a/source/module_base/module_device/memory_op.cpp b/source/module_base/module_device/memory_op.cpp index 625b535051..8f74c016bb 100644 --- a/source/module_base/module_device/memory_op.cpp +++ b/source/module_base/module_device/memory_op.cpp @@ -4,6 +4,7 @@ #include "module_base/tool_threading.h" #ifdef __DSP #include "module_base/kernels/dsp/dsp_connector.h" +#include "module_base/global_variable.h" #endif #include @@ -28,7 +29,7 @@ struct resize_memory_op #endif } #ifdef __DSP - arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size); + arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK); #else arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size); #endif diff --git a/source/module_esolver/esolver_ks_pw.cpp b/source/module_esolver/esolver_ks_pw.cpp index bf6c0bc450..f79454891f 100644 --- a/source/module_esolver/esolver_ks_pw.cpp +++ b/source/module_esolver/esolver_ks_pw.cpp @@ -73,7 +73,7 @@ ESolver_KS_PW::ESolver_KS_PW() #endif #ifdef __DSP std::cout << " ** Initializing DSP Hardware..." << std::endl; - dspInitHandle(GlobalV::MY_RANK % 4); + dspInitHandle(GlobalV::MY_RANK); #endif } @@ -102,7 +102,7 @@ ESolver_KS_PW::~ESolver_KS_PW() } #ifdef __DSP std::cout << " ** Closing DSP Hardware..." << std::endl; - dspDestoryHandle(); + dspDestoryHandle(GlobalV::MY_RANK); #endif if (PARAM.inp.precision == "single") { From c5ecb4ff945a6066c7711cc649978726aea75910 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Fri, 25 Oct 2024 15:30:00 +0800 Subject: [PATCH 3/3] Temporarily remove memory_op porting --- source/module_base/module_device/memory_op.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/source/module_base/module_device/memory_op.cpp b/source/module_base/module_device/memory_op.cpp index 8f74c016bb..00c4a36ad7 100644 --- a/source/module_base/module_device/memory_op.cpp +++ b/source/module_base/module_device/memory_op.cpp @@ -22,17 +22,9 @@ struct resize_memory_op { if (arr != nullptr) { -#ifdef __DSP - free_ht(arr); -#else free(arr); -#endif } -#ifdef __DSP - arr = (FPTYPE*)malloc_ht(sizeof(FPTYPE) * size, GlobalV::MY_RANK); -#else arr = (FPTYPE*)malloc(sizeof(FPTYPE) * size); -#endif std::string record_string; if (record_in != nullptr) { @@ -104,11 +96,7 @@ struct delete_memory_op { void operator()(const base_device::DEVICE_CPU* dev, FPTYPE* arr) { -#ifdef __DSP - free_ht(arr); -#else free(arr); -#endif } };