From e8e86f6d39577541a338209cc5eb9d5586264edf Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Mon, 29 Sep 2025 14:56:55 +0800 Subject: [PATCH 1/3] Use new mtblas library --- .../source_base/kernels/dsp/dsp_connector.cpp | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/source/source_base/kernels/dsp/dsp_connector.cpp b/source/source_base/kernels/dsp/dsp_connector.cpp index 01e68bfbc4..dc192d71d1 100644 --- a/source/source_base/kernels/dsp/dsp_connector.cpp +++ b/source/source_base/kernels/dsp/dsp_connector.cpp @@ -6,9 +6,9 @@ extern "C" { #define complex_double ignore_complex_double -#include // MTBLAS_TRANSPOSE etc +#include // include faster mtblas kernels #undef complex_double -#include // gemm +#include // include normal mtblas kernels that automatically operate memory, but slower. } namespace mtfunc { @@ -24,22 +24,24 @@ void dspDestoryHandle(int id) std::cout << " ** DSP closed on cluster " << id << " **" << std::endl; } // Close dsp cluster at the end -MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) +// MTBlas secretly removed its MTBLAS_TRANSPOSE data type and used the original CBLAS_TRANSPOSE + +CBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) { switch (blasTrans[0]) { case 'N': case 'n': - return MtblasNoTrans; + return CblasNoTrans; case 'T': case 't': - return MtblasTrans; + return CblasTrans; case 'C': case 'c': - return MtblasConjTrans; + return CblasConjTrans; default: std::cout << "Invalid BLAS transpose parameter!! Use default instead." 
<< std::endl; - return MtblasNoTrans; + return CblasNoTrans; } } // Used to convert normal transpost char to mtblas transpose flag @@ -77,7 +79,7 @@ void sgemm_mt_(const char* transa, const int* ldc, int cluster_id) { - mtblas_sgemm(MTBLAS_ORDER::MtblasColMajor, + mtblas_sgemm(CBLAS_ORDER::CblasColMajor, convertBLASTranspose(transa), convertBLASTranspose(transb), *m, @@ -109,7 +111,7 @@ void dgemm_mt_(const char* transa, const int* ldc, int cluster_id) { - mtblas_dgemm(MTBLAS_ORDER::MtblasColMajor, + mtblas_dgemm(CBLAS_ORDER::CblasColMajor, convertBLASTranspose(transa), convertBLASTranspose(transb), *m, @@ -141,7 +143,7 @@ void zgemm_mt_(const char* transa, const int* ldc, int cluster_id) { - mtblas_zgemm(MTBLAS_ORDER::MtblasColMajor, + mtblas_zgemm(CBLAS_ORDER::CblasColMajor, convertBLASTranspose(transa), convertBLASTranspose(transb), *m, @@ -173,7 +175,7 @@ void cgemm_mt_(const char* transa, const int* ldc, int cluster_id) { - mtblas_cgemm(MTBLAS_ORDER::MtblasColMajor, + mtblas_cgemm(CBLAS_ORDER::CblasColMajor, convertBLASTranspose(transa), convertBLASTranspose(transb), *m, @@ -207,7 +209,7 @@ void sgemm_mth_(const char* transa, const int* ldc, int cluster_id) { - mt_hthread_sgemm(MTBLAS_ORDER::MtblasColMajor, + mt_hthread_sgemm(CBLAS_ORDER::CblasColMajor, convertBLASTranspose(transa), convertBLASTranspose(transb), *m, @@ -239,7 +241,7 @@ void dgemm_mth_(const char* transa, const int* ldc, int cluster_id) { - mt_hthread_dgemm(MTBLAS_ORDER::MtblasColMajor, + mt_hthread_dgemm(CBLAS_ORDER::CblasColMajor, convertBLASTranspose(transa), convertBLASTranspose(transb), *m, @@ -275,7 +277,7 @@ void zgemm_mth_(const char* transa, *alp = *alpha; std::complex* bet = (std::complex*)malloc_ht(sizeof(std::complex), cluster_id); *bet = *beta; - mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor, + mt_hthread_zgemm(CBLAS_ORDER::CblasColMajor, convertBLASTranspose(transa), convertBLASTranspose(transb), *m, @@ -314,7 +316,7 @@ void cgemm_mth_(const char* transa, std::complex* bet = 
(std::complex*)malloc_ht(sizeof(std::complex), cluster_id); *bet = *beta; - mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor, + mt_hthread_cgemm(CBLAS_ORDER::CblasColMajor, convertBLASTranspose(transa), convertBLASTranspose(transb), *m, From 7ac340a0d5db55d91c7063b18dac63c0f21b0dcf Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Mon, 29 Sep 2025 14:58:34 +0800 Subject: [PATCH 2/3] Update annotations --- .../source_base/kernels/dsp/dsp_connector.cpp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/source/source_base/kernels/dsp/dsp_connector.cpp b/source/source_base/kernels/dsp/dsp_connector.cpp index dc192d71d1..3db4b47e8e 100644 --- a/source/source_base/kernels/dsp/dsp_connector.cpp +++ b/source/source_base/kernels/dsp/dsp_connector.cpp @@ -22,9 +22,9 @@ void dspDestoryHandle(int id) { hthread_dev_close(id); std::cout << " ** DSP closed on cluster " << id << " **" << std::endl; -} // Close dsp cluster at the end +} // Close dsp cluster at the end of the program -// MTBlas secretly removed its MTBLAS_TRANSPOSE data type and used the original CBLAS_TRANSPOSE +// MTBlas secretly removed its MTBLAS_TRANSPOSE data type and used the original CBLAS_TRANSPOSE. So this function is modified. CBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) { @@ -43,26 +43,21 @@ CBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans) std::cout << "Invalid BLAS transpose parameter!! Use default instead." << std::endl; return CblasNoTrans; } -} // Used to convert normal transpost char to mtblas transpose flag +} // Used to convert normal transpose char to cblas transpose flag void* malloc_ht(size_t bytes, int cluster_id) { - // std::cout << "MALLOC " << cluster_id; void* ptr = hthread_malloc((int)cluster_id, bytes, HT_MEM_RW); - // std::cout << ptr << " SUCCEED" << std::endl;; return ptr; -} +} // Malloc on dsp. 
Used to replace original malloc + -// Used to replace original malloc void free_ht(void* ptr) { - // std::cout << "FREE " << ptr; hthread_free(ptr); - // std::cout << " FREE SUCCEED" << std::endl; -} +} // Free on dsp. Used to replace original free -// Used to replace original free void sgemm_mt_(const char* transa, const char* transb, From 007740a9634a8dd3ec7f5ceb69443b938cf58de3 Mon Sep 17 00:00:00 2001 From: Critsium-xy Date: Mon, 29 Sep 2025 16:33:56 +0800 Subject: [PATCH 3/3] Modify fft --- source/source_base/module_fft/fft_dsp.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/source_base/module_fft/fft_dsp.cpp b/source/source_base/module_fft/fft_dsp.cpp index 1bd8463c91..82ea934f5b 100644 --- a/source/source_base/module_fft/fft_dsp.cpp +++ b/source/source_base/module_fft/fft_dsp.cpp @@ -83,14 +83,14 @@ void FFT_DSP::resource_handler(const int flag) const template <> void FFT_DSP::fft3D_forward(std::complex* in, std::complex* out) const { - hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for); + hthread_group_exec(thread_id_for, "execute_mtfft_3d", 1, 1, args_for); hthread_group_wait(thread_id_for); } template <> void FFT_DSP::fft3D_backward(std::complex* in, std::complex* out) const { - hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back); + hthread_group_exec(thread_id_for, "execute_mtfft_3d", 1, 1, args_back); hthread_group_wait(thread_id_for); } template <>