66extern " C"
77{
88#define complex_double ignore_complex_double
9- #include < mt_hthread_blas.h> // MTBLAS_TRANSPOSE etc
9+ #include < mt_hthread_blas.h> // include faster mtblas kernels
1010#undef complex_double
11- #include < mtblas_interface.h> // gemm
11+ #include < mtblas_interface.h> // include normal mtblas kernels that automatically operate memory, but slower.
1212}
1313namespace mtfunc
1414{
@@ -22,45 +22,42 @@ void dspDestoryHandle(int id)
2222{
2323 hthread_dev_close (id);
2424 std::cout << " ** DSP closed on cluster " << id << " **" << std::endl;
25- } // Close dsp cluster at the end
25+ } // Close dsp cluster at the end of the program
2626
27- MTBLAS_TRANSPOSE convertBLASTranspose (const char * blasTrans)
27+ // MTBlas secretly removed its MTBLAS_TRANSPOSE data type and used the original CBLAS_TRANSPOSE. So this function is modified.
28+
29+ CBLAS_TRANSPOSE convertBLASTranspose (const char * blasTrans)
2830{
2931 switch (blasTrans[0 ])
3032 {
3133 case ' N' :
3234 case ' n' :
33- return MtblasNoTrans ;
35+ return CblasNoTrans ;
3436 case ' T' :
3537 case ' t' :
36- return MtblasTrans ;
38+ return CblasTrans ;
3739 case ' C' :
3840 case ' c' :
39- return MtblasConjTrans ;
41+ return CblasConjTrans ;
4042 default :
4143 std::cout << " Invalid BLAS transpose parameter!! Use default instead." << std::endl;
42- return MtblasNoTrans ;
44+ return CblasNoTrans ;
4345 }
44- } // Used to convert normal transpost char to mtblas transpose flag
46+ } // Used to convert normal transpost char to cblas transpose flag
4547
4648void * malloc_ht (size_t bytes, int cluster_id)
4749{
48- // std::cout << "MALLOC " << cluster_id;
4950 void * ptr = hthread_malloc ((int )cluster_id, bytes, HT_MEM_RW);
50- // std::cout << ptr << " SUCCEED" << std::endl;;
5151 return ptr;
52- }
52+ } // Malloc on dsp. Used to replace original malloc
53+
5354
54- // Used to replace original malloc
5555
5656void free_ht (void * ptr)
5757{
58- // std::cout << "FREE " << ptr;
5958 hthread_free (ptr);
60- // std::cout << " FREE SUCCEED" << std::endl;
61- }
59+ } // Free on dsp. Used to replace original free
6260
63- // Used to replace original free
6461
6562void sgemm_mt_ (const char * transa,
6663 const char * transb,
@@ -77,7 +74,7 @@ void sgemm_mt_(const char* transa,
7774 const int * ldc,
7875 int cluster_id)
7976{
80- mtblas_sgemm (MTBLAS_ORDER::MtblasColMajor ,
77+ mtblas_sgemm (CBLAS_ORDER::CblasColMajor ,
8178 convertBLASTranspose (transa),
8279 convertBLASTranspose (transb),
8380 *m,
@@ -109,7 +106,7 @@ void dgemm_mt_(const char* transa,
109106 const int * ldc,
110107 int cluster_id)
111108{
112- mtblas_dgemm (MTBLAS_ORDER::MtblasColMajor ,
109+ mtblas_dgemm (CBLAS_ORDER::CblasColMajor ,
113110 convertBLASTranspose (transa),
114111 convertBLASTranspose (transb),
115112 *m,
@@ -141,7 +138,7 @@ void zgemm_mt_(const char* transa,
141138 const int * ldc,
142139 int cluster_id)
143140{
144- mtblas_zgemm (MTBLAS_ORDER::MtblasColMajor ,
141+ mtblas_zgemm (CBLAS_ORDER::CblasColMajor ,
145142 convertBLASTranspose (transa),
146143 convertBLASTranspose (transb),
147144 *m,
@@ -173,7 +170,7 @@ void cgemm_mt_(const char* transa,
173170 const int * ldc,
174171 int cluster_id)
175172{
176- mtblas_cgemm (MTBLAS_ORDER::MtblasColMajor ,
173+ mtblas_cgemm (CBLAS_ORDER::CblasColMajor ,
177174 convertBLASTranspose (transa),
178175 convertBLASTranspose (transb),
179176 *m,
@@ -207,7 +204,7 @@ void sgemm_mth_(const char* transa,
207204 const int * ldc,
208205 int cluster_id)
209206{
210- mt_hthread_sgemm (MTBLAS_ORDER::MtblasColMajor ,
207+ mt_hthread_sgemm (CBLAS_ORDER::CblasColMajor ,
211208 convertBLASTranspose (transa),
212209 convertBLASTranspose (transb),
213210 *m,
@@ -239,7 +236,7 @@ void dgemm_mth_(const char* transa,
239236 const int * ldc,
240237 int cluster_id)
241238{
242- mt_hthread_dgemm (MTBLAS_ORDER::MtblasColMajor ,
239+ mt_hthread_dgemm (CBLAS_ORDER::CblasColMajor ,
243240 convertBLASTranspose (transa),
244241 convertBLASTranspose (transb),
245242 *m,
@@ -275,7 +272,7 @@ void zgemm_mth_(const char* transa,
275272 *alp = *alpha;
276273 std::complex <double >* bet = (std::complex <double >*)malloc_ht (sizeof (std::complex <double >), cluster_id);
277274 *bet = *beta;
278- mt_hthread_zgemm (MTBLAS_ORDER::MtblasColMajor ,
275+ mt_hthread_zgemm (CBLAS_ORDER::CblasColMajor ,
279276 convertBLASTranspose (transa),
280277 convertBLASTranspose (transb),
281278 *m,
@@ -314,7 +311,7 @@ void cgemm_mth_(const char* transa,
314311 std::complex <float >* bet = (std::complex <float >*)malloc_ht (sizeof (std::complex <float >), cluster_id);
315312 *bet = *beta;
316313
317- mt_hthread_cgemm (MTBLAS_ORDER::MtblasColMajor ,
314+ mt_hthread_cgemm (CBLAS_ORDER::CblasColMajor ,
318315 convertBLASTranspose (transa),
319316 convertBLASTranspose (transb),
320317 *m,
0 commit comments