Skip to content

Commit b50284b

Browse files
committed
Add DSP gemm pack with auto memcpy to buffer
1 parent 7d641c1 commit b50284b

File tree

3 files changed

+177
-20
lines changed

3 files changed

+177
-20
lines changed

source/source_base/kernels/dsp/dsp_connector.cpp

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,66 @@ void zgemm_mth_(const char* transa,
403403
free_ht(bet);
404404
} // zgemm that needn't malloc_ht or free_ht
405405

406+
void zgemm_pack_mth_(const char* transa,
407+
const char* transb,
408+
const int* m,
409+
const int* n,
410+
const int* k,
411+
const std::complex<double>* alpha,
412+
const std::complex<double>* a,
413+
const int* lda,
414+
const std::complex<double>* b,
415+
const int* ldb,
416+
const std::complex<double>* beta,
417+
std::complex<double>* c,
418+
const int* ldc,
419+
int cluster_id)
420+
{
421+
const bool transa_not = (transa[0] == 'N' || transa[0] == 'n');
422+
const bool transb_not = (transb[0] == 'N' || transb[0] == 'n');
423+
// const size_t a_elems = static_cast<size_t>(*lda) * (transa_not ? static_cast<size_t>(*k) : static_cast<size_t>(*m));
424+
// const size_t b_elems = static_cast<size_t>(*ldb) * (transb_not ? static_cast<size_t>(*n) : static_cast<size_t>(*k));
425+
const size_t c_elems = static_cast<size_t>(*ldc) * static_cast<size_t>(*n);
426+
427+
// std::complex<double>* A_dsp = static_cast<std::complex<double>*>(malloc_ht(a_elems * sizeof(std::complex<double>), cluster_id));
428+
// std::complex<double>* B_dsp = static_cast<std::complex<double>*>(malloc_ht(b_elems * sizeof(std::complex<double>), cluster_id));
429+
std::complex<double>* C_dsp = static_cast<std::complex<double>*>(malloc_ht(c_elems * sizeof(std::complex<double>), cluster_id));
430+
std::complex<double>* alp = static_cast<std::complex<double>*>(malloc_ht(sizeof(std::complex<double>), cluster_id));
431+
std::complex<double>* bet = static_cast<std::complex<double>*>(malloc_ht(sizeof(std::complex<double>), cluster_id));
432+
433+
// memcpy(A_dsp, a, a_elems * sizeof(std::complex<double>));
434+
// memcpy(B_dsp, b, b_elems * sizeof(std::complex<double>));
435+
memcpy(C_dsp, c, c_elems * sizeof(std::complex<double>));
436+
*alp = *alpha;
437+
*bet = *beta;
438+
439+
mt_hthread_zgemm(CBLAS_ORDER::CblasColMajor,
440+
convertBLASTranspose(transa),
441+
convertBLASTranspose(transb),
442+
*m,
443+
*n,
444+
*k,
445+
alp,
446+
a,
447+
// A_dsp,
448+
*lda,
449+
b,
450+
// B_dsp,
451+
*ldb,
452+
bet,
453+
// c,
454+
C_dsp,
455+
*ldc,
456+
cluster_id);
457+
memcpy(c, C_dsp, c_elems * sizeof(std::complex<double>));
458+
459+
// free_ht(A_dsp);
460+
// free_ht(B_dsp);
461+
free_ht(C_dsp);
462+
free_ht(alp);
463+
free_ht(bet);
464+
}
465+
406466
void cgemm_mth_(const char* transa,
407467
const char* transb,
408468
const int* m,
@@ -443,6 +503,64 @@ void cgemm_mth_(const char* transa,
443503
free_ht(bet);
444504
} // cgemm that needn't malloc_ht or free_ht
445505

506+
void cgemm_pack_mth_(const char* transa,
507+
const char* transb,
508+
const int* m,
509+
const int* n,
510+
const int* k,
511+
const std::complex<float>* alpha,
512+
const std::complex<float>* a,
513+
const int* lda,
514+
const std::complex<float>* b,
515+
const int* ldb,
516+
const std::complex<float>* beta,
517+
std::complex<float>* c,
518+
const int* ldc,
519+
int cluster_id)
520+
{
521+
const bool transa_not = (transa[0] == 'N' || transa[0] == 'n');
522+
const bool transb_not = (transb[0] == 'N' || transb[0] == 'n');
523+
const size_t a_elems = static_cast<size_t>(*lda) * (transa_not ? static_cast<size_t>(*k) : static_cast<size_t>(*m));
524+
const size_t b_elems = static_cast<size_t>(*ldb) * (transb_not ? static_cast<size_t>(*n) : static_cast<size_t>(*k));
525+
const size_t c_elems = static_cast<size_t>(*ldc) * static_cast<size_t>(*n);
526+
527+
std::complex<float>* A_dsp = static_cast<std::complex<float>*>(malloc_ht(a_elems * sizeof(std::complex<float>), cluster_id));
528+
std::complex<float>* B_dsp = static_cast<std::complex<float>*>(malloc_ht(b_elems * sizeof(std::complex<float>), cluster_id));
529+
std::complex<float>* C_dsp = static_cast<std::complex<float>*>(malloc_ht(c_elems * sizeof(std::complex<float>), cluster_id));
530+
std::complex<float>* alp = static_cast<std::complex<float>*>(malloc_ht(sizeof(std::complex<float>), cluster_id));
531+
std::complex<float>* bet = static_cast<std::complex<float>*>(malloc_ht(sizeof(std::complex<float>), cluster_id));
532+
533+
memcpy(A_dsp, a, a_elems * sizeof(std::complex<float>));
534+
memcpy(B_dsp, b, b_elems * sizeof(std::complex<float>));
535+
memcpy(C_dsp, c, c_elems * sizeof(std::complex<float>));
536+
*alp = *alpha;
537+
*bet = *beta;
538+
539+
mt_hthread_cgemm(CBLAS_ORDER::CblasColMajor,
540+
convertBLASTranspose(transa),
541+
convertBLASTranspose(transb),
542+
*m,
543+
*n,
544+
*k,
545+
(const void*)alp,
546+
(const void*)A_dsp,
547+
*lda,
548+
(const void*)B_dsp,
549+
*ldb,
550+
(const void*)bet,
551+
(void*)C_dsp,
552+
*ldc,
553+
cluster_id);
554+
555+
memcpy(c, C_dsp, c_elems * sizeof(std::complex<float>));
556+
557+
free_ht(A_dsp);
558+
free_ht(B_dsp);
559+
free_ht(C_dsp);
560+
free_ht(alp);
561+
free_ht(bet);
562+
}
563+
446564
void sgemv_mth_(const char* transa,
447565
const int* m,
448566
const int* n,
@@ -570,4 +688,4 @@ void cgemv_mth_(const char* transa,
570688
free_ht(alp);
571689
free_ht(bet);
572690
}
573-
} // namespace mtfunc
691+
} // namespace mtfunc

source/source_base/kernels/dsp/dsp_connector.h

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,51 @@ void zgemm_mt_(const char* transa,
6161
const int* ldc,
6262
int cluster_id);
6363

64-
void cgemm_mt_(const char* transa,
65-
const char* transb,
66-
const int* m,
67-
const int* n,
68-
const int* k,
69-
const std::complex<float>* alpha,
70-
const std::complex<float>* a,
71-
const int* lda,
72-
const std::complex<float>* b,
73-
const int* ldb,
74-
const std::complex<float>* beta,
75-
std::complex<float>* c,
76-
const int* ldc,
77-
int cluster_id);
64+
void zgemm_pack_mth_(const char* transa,
65+
const char* transb,
66+
const int* m,
67+
const int* n,
68+
const int* k,
69+
const std::complex<double>* alpha,
70+
const std::complex<double>* a,
71+
const int* lda,
72+
const std::complex<double>* b,
73+
const int* ldb,
74+
const std::complex<double>* beta,
75+
std::complex<double>* c,
76+
const int* ldc,
77+
int cluster_id);
78+
79+
80+
void cgemm_mth_(const char* transa,
81+
const char* transb,
82+
const int* m,
83+
const int* n,
84+
const int* k,
85+
const std::complex<float>* alpha,
86+
const std::complex<float>* a,
87+
const int* lda,
88+
const std::complex<float>* b,
89+
const int* ldb,
90+
const std::complex<float>* beta,
91+
std::complex<float>* c,
92+
const int* ldc,
93+
int cluster_id);
94+
95+
void cgemm_pack_mth_(const char* transa,
96+
const char* transb,
97+
const int* m,
98+
const int* n,
99+
const int* k,
100+
const std::complex<float>* alpha,
101+
const std::complex<float>* a,
102+
const int* lda,
103+
const std::complex<float>* b,
104+
const int* ldb,
105+
const std::complex<float>* beta,
106+
std::complex<float>* c,
107+
const int* ldc,
108+
int cluster_id);
78109

79110
void sgemv_mt_(const char* transa,
80111
const int* m,
@@ -282,4 +313,4 @@ void dsp_dav_subspace_reduce(T* hcc, T* scc, int nbase, int nbase_x, int notconv
282313
} // namespace mtfunc
283314

284315
#endif
285-
#endif
316+
#endif

source/source_base/module_external/blas_connector_matrix.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ void BlasConnector::gemm(const char transa,
107107
#ifdef __DSP
108108
else if (device_type == base_device::AbacusDevice_t::DspDevice)
109109
{
110-
mtfunc::cgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
110+
mtfunc::cgemm_pack_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
111+
// cgemm_mth_ for raw dsp mth;
112+
// cgemm_pack_mth_ for dsp mth with memcpy to DSP buffer
111113
}
112114
#endif
113115
else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -158,7 +160,9 @@ void BlasConnector::gemm(const char transa,
158160
#ifdef __DSP
159161
else if (device_type == base_device::AbacusDevice_t::DspDevice)
160162
{
161-
mtfunc::zgemm_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
163+
mtfunc::zgemm_pack_mth_(&transb, &transa, &n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
164+
// zgemm_mth_ for raw dsp mth;
165+
// zgemm_pack_mth_ for dsp mth with memcpy to DSP buffer
162166
}
163167
#endif
164168
else if (device_type == base_device::AbacusDevice_t::GpuDevice)
@@ -277,7 +281,9 @@ void BlasConnector::gemm_cm(const char transa,
277281
#ifdef __DSP
278282
else if (device_type == base_device::AbacusDevice_t::DspDevice)
279283
{
280-
mtfunc::cgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
284+
mtfunc::cgemm_pack_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
285+
// cgemm_mth_ for raw dsp mth;
286+
// cgemm_pack_mth_ for dsp mth with memcpy to DSP buffer
281287
}
282288
#endif
283289
#ifdef __CUDA
@@ -328,7 +334,9 @@ void BlasConnector::gemm_cm(const char transa,
328334
#ifdef __DSP
329335
else if (device_type == base_device::AbacusDevice_t::DspDevice)
330336
{
331-
mtfunc::zgemm_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
337+
mtfunc::zgemm_pack_mth_(&transa, &transb, &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc, GlobalV::MY_RANK % PARAM.inp.dsp_count);
338+
// zgemm_mth_ for raw dsp mth;
339+
// zgemm_pack_mth_ for dsp mth with memcpy to DSP buffer
332340
}
333341
#endif
334342
#ifdef __CUDA

0 commit comments

Comments
 (0)