Skip to content

Commit c752b6e

Browse files
committed
add change for the gemm
1 parent d7589cd commit c752b6e

File tree

2 files changed

+28
-28
lines changed

2 files changed

+28
-28
lines changed

source/source_base/kernels/dsp/dsp_connector.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,33 @@ extern "C"
1212
}
1313
namespace mtfunc
1414
{
15-
std::complex<double>* gemm_alp_double=nullptr;
16-
std::complex<double>* gemm_bet_double=nullptr;
17-
std::complex<float>* gemm_alp_float=nullptr;
18-
std::complex<float>* gemm_bet_float=nullptr;
15+
std::complex<double>* gemm_alpha_double=nullptr;
16+
std::complex<double>* gemm_beta_double=nullptr;
17+
std::complex<float>* gemm_alpha_float=nullptr;
18+
std::complex<float>* gemm_beta_float=nullptr;
1919

2020
void dspInitHandle(int id)
2121
{
2222
mt_blas_init(id);
2323
std::cout << " ** DSP inited on cluster " << id << " **" << std::endl;
24-
mtfunc::gemm_alp_double=(std::complex<double>*)mtfunc::malloc_ht(sizeof(std::complex<double>), id);
25-
mtfunc::gemm_bet_double=(std::complex<double>*)mtfunc::malloc_ht(sizeof(std::complex<double>), id);
26-
mtfunc::gemm_alp_float=(std::complex<float>*)mtfunc::malloc_ht(sizeof(std::complex<float>), id);
27-
mtfunc::gemm_bet_float=(std::complex<float>*)mtfunc::malloc_ht(sizeof(std::complex<float>), id);
24+
mtfunc::gemm_alpha_double=(std::complex<double>*)mtfunc::malloc_ht(sizeof(std::complex<double>), id);
25+
mtfunc::gemm_beta_double=(std::complex<double>*)mtfunc::malloc_ht(sizeof(std::complex<double>), id);
26+
mtfunc::gemm_alpha_float=(std::complex<float>*)mtfunc::malloc_ht(sizeof(std::complex<float>), id);
27+
mtfunc::gemm_beta_float=(std::complex<float>*)mtfunc::malloc_ht(sizeof(std::complex<float>), id);
2828
} // Use this at the beginning of the program to start a dsp cluster
2929

3030
void dspDestoryHandle(int id)
3131
{
3232
hthread_dev_close(id);
3333
std::cout << " ** DSP closed on cluster " << id << " **" << std::endl;
34-
mtfunc::free_ht(mtfunc::gemm_alp_double);
35-
mtfunc::free_ht(mtfunc::gemm_bet_double);
36-
mtfunc::free_ht(mtfunc::gemm_alp_float);
37-
mtfunc::free_ht(mtfunc::gemm_bet_float);
38-
mtfunc::gemm_alp_double = nullptr;
39-
mtfunc::gemm_bet_double = nullptr;
40-
mtfunc::gemm_alp_float = nullptr;
41-
mtfunc::gemm_bet_float = nullptr;
34+
mtfunc::free_ht(mtfunc::gemm_alpha_double);
35+
mtfunc::free_ht(mtfunc::gemm_beta_double);
36+
mtfunc::free_ht(mtfunc::gemm_alpha_float);
37+
mtfunc::free_ht(mtfunc::gemm_beta_float);
38+
mtfunc::gemm_alpha_double = nullptr;
39+
mtfunc::gemm_beta_double = nullptr;
40+
mtfunc::gemm_alpha_float = nullptr;
41+
mtfunc::gemm_beta_float = nullptr;
4242
} // Close dsp cluster at the end
4343

4444
MTBLAS_TRANSPOSE convertBLASTranspose(const char* blasTrans)
@@ -284,20 +284,20 @@ void zgemm_mth_(const char* transa,
284284
const int* ldc,
285285
int cluster_id)
286286
{
287-
*gemm_alp_double = *alpha;
288-
*gemm_bet_double = *beta;
287+
*gemm_alpha_double = *alpha;
288+
*gemm_beta_double = *beta;
289289
mt_hthread_zgemm(MTBLAS_ORDER::MtblasColMajor,
290290
convertBLASTranspose(transa),
291291
convertBLASTranspose(transb),
292292
*m,
293293
*n,
294294
*k,
295-
gemm_alp_double,
295+
gemm_alpha_double,
296296
a,
297297
*lda,
298298
b,
299299
*ldb,
300-
gemm_bet_double,
300+
gemm_beta_double,
301301
c,
302302
*ldc,
303303
cluster_id);
@@ -319,21 +319,21 @@ void cgemm_mth_(const char* transa,
319319
const int* ldc,
320320
int cluster_id)
321321
{
322-
gemm_alp_float = alpha;
323-
gemm_bet_float = beta;
322+
gemm_alpha_float = alpha;
323+
gemm_beta_float = beta;
324324

325325
mt_hthread_cgemm(MTBLAS_ORDER::MtblasColMajor,
326326
convertBLASTranspose(transa),
327327
convertBLASTranspose(transb),
328328
*m,
329329
*n,
330330
*k,
331-
(const void*)gemm_alp_float,
331+
(const void*)gemm_alpha_float,
332332
(const void*)a,
333333
*lda,
334334
(const void*)b,
335335
*ldb,
336-
(const void*)gemm_bet_float,
336+
(const void*)gemm_beta_float,
337337
(void*)c,
338338
*ldc,
339339
cluster_id);

source/source_base/kernels/dsp/dsp_connector.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ void* malloc_ht(size_t bytes, int cluster_id);
1515
void free_ht(void* ptr);
1616

1717
// mtblas functions
18-
extern std::complex<double>* gemm_alp_double;
19-
extern std::complex<double>* gemm_bet_double;
20-
extern std::complex<float>* gemm_alp_float;
21-
extern std::complex<float>* gemm_bet_float;
18+
extern std::complex<double>* gemm_alpha_double;
19+
extern std::complex<double>* gemm_beta_double;
20+
extern std::complex<float>* gemm_alpha_float;
21+
extern std::complex<float>* gemm_beta_float;
2222
void sgemm_mt_(const char* transa,
2323
const char* transb,
2424
const int* m,

0 commit comments

Comments
 (0)