Skip to content

Commit ad77994

Browse files
authored
Fix: Add DSP gemm pack with auto memcpy to buffer (#7060)
* Add DSP gemm pack with auto memcpy to buffer * Fix and reorder gemm signature
1 parent e277c1c commit ad77994

3 files changed

Lines changed: 177 additions & 19 deletions

File tree

source/source_base/kernels/dsp/dsp_connector.cpp

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,66 @@ void zgemm_mth_(const char* transa,
403403
free_ht(bet);
404404
} // zgemm that needn't malloc_ht or free_ht
405405

406+
void zgemm_pack_mth_(const char* transa,
407+
const char* transb,
408+
const int* m,
409+
const int* n,
410+
const int* k,
411+
const std::complex<double>* alpha,
412+
const std::complex<double>* a,
413+
const int* lda,
414+
const std::complex<double>* b,
415+
const int* ldb,
416+
const std::complex<double>* beta,
417+
std::complex<double>* c,
418+
const int* ldc,
419+
int cluster_id)
420+
{
421+
const bool transa_not = (transa[0] == 'N' || transa[0] == 'n');
422+
const bool transb_not = (transb[0] == 'N' || transb[0] == 'n');
423+
// const size_t a_elems = static_cast<size_t>(*lda) * (transa_not ? static_cast<size_t>(*k) : static_cast<size_t>(*m));
424+
// const size_t b_elems = static_cast<size_t>(*ldb) * (transb_not ? static_cast<size_t>(*n) : static_cast<size_t>(*k));
425+
const size_t c_elems = static_cast<size_t>(*ldc) * static_cast<size_t>(*n);
426+
427+
// std::complex<double>* A_dsp = static_cast<std::complex<double>*>(malloc_ht(a_elems * sizeof(std::complex<double>), cluster_id));
428+
// std::complex<double>* B_dsp = static_cast<std::complex<double>*>(malloc_ht(b_elems * sizeof(std::complex<double>), cluster_id));
429+
std::complex<double>* C_dsp = static_cast<std::complex<double>*>(malloc_ht(c_elems * sizeof(std::complex<double>), cluster_id));
430+
std::complex<double>* alp = static_cast<std::complex<double>*>(malloc_ht(sizeof(std::complex<double>), cluster_id));
431+
std::complex<double>* bet = static_cast<std::complex<double>*>(malloc_ht(sizeof(std::complex<double>), cluster_id));
432+
433+
// memcpy(A_dsp, a, a_elems * sizeof(std::complex<double>));
434+
// memcpy(B_dsp, b, b_elems * sizeof(std::complex<double>));
435+
memcpy(C_dsp, c, c_elems * sizeof(std::complex<double>));
436+
*alp = *alpha;
437+
*bet = *beta;
438+
439+
mt_hthread_zgemm(CBLAS_ORDER::CblasColMajor,
440+
convertBLASTranspose(transa),
441+
convertBLASTranspose(transb),
442+
*m,
443+
*n,
444+
*k,
445+
alp,
446+
a,
447+
// A_dsp,
448+
*lda,
449+
b,
450+
// B_dsp,
451+
*ldb,
452+
bet,
453+
// c,
454+
C_dsp,
455+
*ldc,
456+
cluster_id);
457+
memcpy(c, C_dsp, c_elems * sizeof(std::complex<double>));
458+
459+
// free_ht(A_dsp);
460+
// free_ht(B_dsp);
461+
free_ht(C_dsp);
462+
free_ht(alp);
463+
free_ht(bet);
464+
}
465+
406466
void cgemm_mth_(const char* transa,
407467
const char* transb,
408468
const int* m,
@@ -443,6 +503,64 @@ void cgemm_mth_(const char* transa,
443503
free_ht(bet);
444504
} // cgemm that needn't malloc_ht or free_ht
445505

506+
void cgemm_pack_mth_(const char* transa,
507+
const char* transb,
508+
const int* m,
509+
const int* n,
510+
const int* k,
511+
const std::complex<float>* alpha,
512+
const std::complex<float>* a,
513+
const int* lda,
514+
const std::complex<float>* b,
515+
const int* ldb,
516+
const std::complex<float>* beta,
517+
std::complex<float>* c,
518+
const int* ldc,
519+
int cluster_id)
520+
{
521+
const bool transa_not = (transa[0] == 'N' || transa[0] == 'n');
522+
const bool transb_not = (transb[0] == 'N' || transb[0] == 'n');
523+
const size_t a_elems = static_cast<size_t>(*lda) * (transa_not ? static_cast<size_t>(*k) : static_cast<size_t>(*m));
524+
const size_t b_elems = static_cast<size_t>(*ldb) * (transb_not ? static_cast<size_t>(*n) : static_cast<size_t>(*k));
525+
const size_t c_elems = static_cast<size_t>(*ldc) * static_cast<size_t>(*n);
526+
527+
std::complex<float>* A_dsp = static_cast<std::complex<float>*>(malloc_ht(a_elems * sizeof(std::complex<float>), cluster_id));
528+
std::complex<float>* B_dsp = static_cast<std::complex<float>*>(malloc_ht(b_elems * sizeof(std::complex<float>), cluster_id));
529+
std::complex<float>* C_dsp = static_cast<std::complex<float>*>(malloc_ht(c_elems * sizeof(std::complex<float>), cluster_id));
530+
std::complex<float>* alp = static_cast<std::complex<float>*>(malloc_ht(sizeof(std::complex<float>), cluster_id));
531+
std::complex<float>* bet = static_cast<std::complex<float>*>(malloc_ht(sizeof(std::complex<float>), cluster_id));
532+
533+
memcpy(A_dsp, a, a_elems * sizeof(std::complex<float>));
534+
memcpy(B_dsp, b, b_elems * sizeof(std::complex<float>));
535+
memcpy(C_dsp, c, c_elems * sizeof(std::complex<float>));
536+
*alp = *alpha;
537+
*bet = *beta;
538+
539+
mt_hthread_cgemm(CBLAS_ORDER::CblasColMajor,
540+
convertBLASTranspose(transa),
541+
convertBLASTranspose(transb),
542+
*m,
543+
*n,
544+
*k,
545+
(const void*)alp,
546+
(const void*)A_dsp,
547+
*lda,
548+
(const void*)B_dsp,
549+
*ldb,
550+
(const void*)bet,
551+
(void*)C_dsp,
552+
*ldc,
553+
cluster_id);
554+
555+
memcpy(c, C_dsp, c_elems * sizeof(std::complex<float>));
556+
557+
free_ht(A_dsp);
558+
free_ht(B_dsp);
559+
free_ht(C_dsp);
560+
free_ht(alp);
561+
free_ht(bet);
562+
}
563+
446564
void sgemv_mth_(const char* transa,
447565
const int* m,
448566
const int* n,
@@ -570,4 +688,4 @@ void cgemv_mth_(const char* transa,
570688
free_ht(alp);
571689
free_ht(bet);
572690
}
573-
} // namespace mtfunc
691+
} // namespace mtfunc

source/source_base/kernels/dsp/dsp_connector.h

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,21 @@ void zgemm_mt_(const char* transa,
6262
int cluster_id);
6363

6464
void cgemm_mt_(const char* transa,
65-
const char* transb,
66-
const int* m,
67-
const int* n,
68-
const int* k,
69-
const std::complex<float>* alpha,
70-
const std::complex<float>* a,
71-
const int* lda,
72-
const std::complex<float>* b,
73-
const int* ldb,
74-
const std::complex<float>* beta,
75-
std::complex<float>* c,
76-
const int* ldc,
77-
int cluster_id);
65+
const char* transb,
66+
const int* m,
67+
const int* n,
68+
const int* k,
69+
const std::complex<float>* alpha,
70+
const std::complex<float>* a,
71+
const int* lda,
72+
const std::complex<float>* b,
73+
const int* ldb,
74+
const std::complex<float>* beta,
75+
std::complex<float>* c,
76+
const int* ldc,
77+
int cluster_id);
78+
79+
7880

7981
void sgemv_mt_(const char* transa,
8082
const int* m,
@@ -173,6 +175,21 @@ void zgemm_mth_(const char* transa,
173175
const int* ldc,
174176
int cluster_id);
175177

178+
// zgemm for the DSP path that copies C into a malloc_ht buffer before calling
// mt_hthread_zgemm and copies the result back into c afterwards; alpha and beta
// are copied into malloc_ht storage too. a and b are passed to the kernel as given.
// Same parameter list as zgemm_mth_; cluster_id is forwarded to malloc_ht and the kernel.
void zgemm_pack_mth_(const char* transa,
179+
const char* transb,
180+
const int* m,
181+
const int* n,
182+
const int* k,
183+
const std::complex<double>* alpha,
184+
const std::complex<double>* a,
185+
const int* lda,
186+
const std::complex<double>* b,
187+
const int* ldb,
188+
const std::complex<double>* beta,
189+
std::complex<double>* c,
190+
const int* ldc,
191+
int cluster_id);
192+
176193
void cgemm_mth_(const char* transa,
177194
const char* transb,
178195
const int* m,
@@ -188,6 +205,21 @@ void cgemm_mth_(const char* transa,
188205
const int* ldc,
189206
int cluster_id);
190207

208+
// cgemm for the DSP path that copies a, b, c, alpha and beta into malloc_ht
// buffers, calls mt_hthread_cgemm on those copies, copies the C buffer back
// into c, and frees the buffers with free_ht.
// Same parameter list as cgemm_mth_; cluster_id is forwarded to malloc_ht and the kernel.
void cgemm_pack_mth_(const char* transa,
209+
const char* transb,
210+
const int* m,
211+
const int* n,
212+
const int* k,
213+
const std::complex<float>* alpha,
214+
const std::complex<float>* a,
215+
const int* lda,
216+
const std::complex<float>* b,
217+
const int* ldb,
218+
const std::complex<float>* beta,
219+
std::complex<float>* c,
220+
const int* ldc,
221+
int cluster_id);
222+
191223
void sgemv_mth_(const char* transa,
192224
const int* m,
193225
const int* n,
@@ -282,4 +314,4 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv
282314
} // namespace mtfunc
283315

284316
#endif
285-
#endif
317+
#endif

source/source_base/module_external/blas_connector_matrix.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ void BlasConnector::gemm(const char transa,
107107
#ifdef __DSP
108108
else if (device_type == base_device::AbacusDevice_t::DspDevice)
109109
{
110-
mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
110+
mtfunc::cgemm_pack_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
111+
// cgemm_mth_ for raw dsp mth;
112+
// cgemm_pack_mth_ for dsp mth with memcpy to DSP buffer
111113
}
112114
#endif
113115
else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -158,7 +160,9 @@ void BlasConnector::gemm(const char transa,
158160
#ifdef __DSP
159161
else if (device_type == base_device::AbacusDevice_t::DspDevice)
160162
{
161-
mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
163+
mtfunc::zgemm_pack_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
164+
// zgemm_mth_ for raw dsp mth;
165+
// zgemm_pack_mth_ for dsp mth with memcpy to DSP buffer
162166
}
163167
#endif
164168
else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -277,7 +281,9 @@ void BlasConnector::gemm_cm(const char transa,
277281
#ifdef __DSP
278282
else if (device_type == base_device::AbacusDevice_t::DspDevice)
279283
{
280-
mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
284+
mtfunc::cgemm_pack_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
285+
// cgemm_mth_ for raw dsp mth;
286+
// cgemm_pack_mth_ for dsp mth with memcpy to DSP buffer
281287
}
282288
#endif
283289
#ifdef __CUDA
@@ -328,7 +334,9 @@ void BlasConnector::gemm_cm(const char transa,
328334
#ifdef __DSP
329335
else if (device_type == base_device::AbacusDevice_t::DspDevice)
330336
{
331-
mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
337+
mtfunc::zgemm_pack_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
338+
// zgemm_mth_ for raw dsp mth;
339+
// zgemm_pack_mth_ for dsp mth with memcpy to DSP buffer
332340
}
333341
#endif
334342
#ifdef __CUDA

0 commit comments

Comments
 (0)