Skip to content

Commit bc2d287

Browse files
authored
[Feature] Adapt ABACUS to newest version of mtblas and mtfft (#6548)
* Use new mtblas library * Update annotations * Modify fft
1 parent 5926001 commit bc2d287

File tree

2 files changed

+24
-27
lines changed

2 files changed

+24
-27
lines changed

source/source_base/kernels/dsp/dsp_connector.cpp

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
extern "C"
77
{
88
#define complex_double ignore_complex_double
9-
#include <mt_hthread_blas.h> // MTBLAS_TRANSPOSE etc
9+
#include <mt_hthread_blas.h> // include faster mtblas kernels
1010
#undef complex_double
11-
#include <mtblas_interface.h> // gemm
11+
#include <mtblas_interface.h> // regular mtblas kernels: they manage memory automatically, but are slower
1212
}
1313
namespace mtfunc
1414
{
@@ -22,45 +22,42 @@ void dspDestoryHandle(int id)
2222
{
2323
hthread_dev_close(id);
2424
std::cout << " ** DSP closed on cluster " << id << " **" << std::endl;
25-
} // Close dsp cluster at the end
25+
} // Close dsp cluster at the end of the program
2626

27-
MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans)
27+
// The newer MTBLAS dropped its own MTBLAS_TRANSPOSE type in favor of the standard CBLAS_TRANSPOSE, so this function now returns CBLAS_TRANSPOSE.
28+
29+
CBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans)
2830
{
2931
switch (blasTrans[0])
3032
{
3133
case 'N':
3234
case 'n':
33-
return MtblasNoTrans;
35+
return CblasNoTrans;
3436
case 'T':
3537
case 't':
36-
return MtblasTrans;
38+
return CblasTrans;
3739
case 'C':
3840
case 'c':
39-
return MtblasConjTrans;
41+
return CblasConjTrans;
4042
default:
4143
std::cout << "Invalid BLAS transpose parameter!! Use default instead." << std::endl;
42-
return MtblasNoTrans;
44+
return CblasNoTrans;
4345
}
44-
} // Used to convert normal transpost char to mtblas transpose flag
46+
} // Used to convert a standard BLAS transpose character to the corresponding CBLAS transpose flag
4547

4648
void* malloc_ht(size_t bytes, int cluster_id)
4749
{
48-
// std::cout << "MALLOC " << cluster_id;
4950
void* ptr = hthread_malloc((int)cluster_id, bytes, HT_MEM_RW);
50-
// std::cout << ptr << " SUCCEED" << std::endl;;
5151
return ptr;
52-
}
52+
} // Malloc on dsp. Used to replace original malloc
53+
5354

54-
// Used to replace original malloc
5555

5656
void free_ht(void* ptr)
5757
{
58-
// std::cout << "FREE " << ptr;
5958
hthread_free(ptr);
60-
// std::cout << " FREE SUCCEED" << std::endl;
61-
}
59+
} // Free on dsp. Used to replace original free
6260

63-
// Used to replace original free
6461

6562
void sgemm_mt_(const char* transa,
6663
const char* transb,
@@ -77,7 +74,7 @@ void sgemm_mt_(const char* transa,
7774
const int* ldc,
7875
int cluster_id)
7976
{
80-
mtblas_sgemm(MTBLAS_ORDER::MtblasColMajor,
77+
mtblas_sgemm(CBLAS_ORDER::CblasColMajor,
8178
convertBLASTranspose(transa),
8279
convertBLASTranspose(transb),
8380
*m,
@@ -109,7 +106,7 @@ void dgemm_mt_(const char* transa,
109106
const int* ldc,
110107
int cluster_id)
111108
{
112-
mtblas_dgemm(MTBLAS_ORDER::MtblasColMajor,
109+
mtblas_dgemm(CBLAS_ORDER::CblasColMajor,
113110
convertBLASTranspose(transa),
114111
convertBLASTranspose(transb),
115112
*m,
@@ -141,7 +138,7 @@ void zgemm_mt_(const char* transa,
141138
const int* ldc,
142139
int cluster_id)
143140
{
144-
mtblas_zgemm(MTBLAS_ORDER::MtblasColMajor,
141+
mtblas_zgemm(CBLAS_ORDER::CblasColMajor,
145142
convertBLASTranspose(transa),
146143
convertBLASTranspose(transb),
147144
*m,
@@ -173,7 +170,7 @@ void cgemm_mt_(const char* transa,
173170
const int* ldc,
174171
int cluster_id)
175172
{
176-
mtblas_cgemm(MTBLAS_ORDER::MtblasColMajor,
173+
mtblas_cgemm(CBLAS_ORDER::CblasColMajor,
177174
convertBLASTranspose(transa),
178175
convertBLASTranspose(transb),
179176
*m,
@@ -207,7 +204,7 @@ void sgemm_mth_(const char* transa,
207204
const int* ldc,
208205
int cluster_id)
209206
{
210-
mt_hthread_sgemm(MTBLAS_ORDER::MtblasColMajor,
207+
mt_hthread_sgemm(CBLAS_ORDER::CblasColMajor,
211208
convertBLASTranspose(transa),
212209
convertBLASTranspose(transb),
213210
*m,
@@ -239,7 +236,7 @@ void dgemm_mth_(const char* transa,
239236
const int* ldc,
240237
int cluster_id)
241238
{
242-
mt_hthread_dgemm(MTBLAS_ORDER::MtblasColMajor,
239+
mt_hthread_dgemm(CBLAS_ORDER::CblasColMajor,
243240
convertBLASTranspose(transa),
244241
convertBLASTranspose(transb),
245242
*m,
@@ -275,7 +272,7 @@ void zgemm_mth_(const char* transa,
275272
*alp = *alpha;
276273
std::complex<double>* bet = (std::complex<double>*)malloc_ht(sizeof(std::complex<double>), cluster_id);
277274
*bet = *beta;
278-
mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor,
275+
mt_hthread_zgemm(CBLAS_ORDER::CblasColMajor,
279276
convertBLASTranspose(transa),
280277
convertBLASTranspose(transb),
281278
*m,
@@ -314,7 +311,7 @@ void cgemm_mth_(const char* transa,
314311
std::complex<float>* bet = (std::complex<float>*)malloc_ht(sizeof(std::complex<float>), cluster_id);
315312
*bet = *beta;
316313

317-
mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor,
314+
mt_hthread_cgemm(CBLAS_ORDER::CblasColMajor,
318315
convertBLASTranspose(transa),
319316
convertBLASTranspose(transb),
320317
*m,

source/source_base/module_fft/fft_dsp.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,14 @@ void FFT_DSP<double>::resource_handler(const int flag) const
8383
template <>
8484
void FFT_DSP<double>::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const
8585
{
86-
hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_for);
86+
hthread_group_exec(thread_id_for, "execute_mtfft_3d", 1, 1, args_for);
8787
hthread_group_wait(thread_id_for);
8888
}
8989

9090
template <>
9191
void FFT_DSP<double>::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const
9292
{
93-
hthread_group_exec(thread_id_for, "execute_device", 1, 1, args_back);
93+
hthread_group_exec(thread_id_for, "execute_mtfft_3d", 1, 1, args_back);
9494
hthread_group_wait(thread_id_for);
9595
}
9696
template <>

0 commit comments

Comments
 (0)