@@ -403,6 +403,66 @@ void zgemm_mth_(const char* transa,
403403 free_ht (bet);
404404} // zgemm that needn't malloc_ht or free_ht
405405
406+ void zgemm_pack_mth_ (const char * transa,
407+ const char * transb,
408+ const int * m,
409+ const int * n,
410+ const int * k,
411+ const std::complex <double >* alpha,
412+ const std::complex <double >* a,
413+ const int * lda,
414+ const std::complex <double >* b,
415+ const int * ldb,
416+ const std::complex <double >* beta,
417+ std::complex <double >* c,
418+ const int * ldc,
419+ int cluster_id)
420+ {
421+ const bool transa_not = (transa[0 ] == ' N' || transa[0 ] == ' n' );
422+ const bool transb_not = (transb[0 ] == ' N' || transb[0 ] == ' n' );
423+ // const size_t a_elems = static_cast<size_t>(*lda) * (transa_not ? static_cast<size_t>(*k) : static_cast<size_t>(*m));
424+ // const size_t b_elems = static_cast<size_t>(*ldb) * (transb_not ? static_cast<size_t>(*n) : static_cast<size_t>(*k));
425+ const size_t c_elems = static_cast <size_t >(*ldc) * static_cast <size_t >(*n);
426+
427+ // std::complex<double>* A_dsp = static_cast<std::complex<double>*>(malloc_ht(a_elems * sizeof(std::complex<double>), cluster_id));
428+ // std::complex<double>* B_dsp = static_cast<std::complex<double>*>(malloc_ht(b_elems * sizeof(std::complex<double>), cluster_id));
429+ std::complex <double >* C_dsp = static_cast <std::complex <double >*>(malloc_ht (c_elems * sizeof (std::complex <double >), cluster_id));
430+ std::complex <double >* alp = static_cast <std::complex <double >*>(malloc_ht (sizeof (std::complex <double >), cluster_id));
431+ std::complex <double >* bet = static_cast <std::complex <double >*>(malloc_ht (sizeof (std::complex <double >), cluster_id));
432+
433+ // memcpy(A_dsp, a, a_elems * sizeof(std::complex<double>));
434+ // memcpy(B_dsp, b, b_elems * sizeof(std::complex<double>));
435+ memcpy (C_dsp, c, c_elems * sizeof (std::complex <double >));
436+ *alp = *alpha;
437+ *bet = *beta;
438+
439+ mt_hthread_zgemm (CBLAS_ORDER::CblasColMajor,
440+ convertBLASTranspose (transa),
441+ convertBLASTranspose (transb),
442+ *m,
443+ *n,
444+ *k,
445+ alp,
446+ a,
447+ // A_dsp,
448+ *lda,
449+ b,
450+ // B_dsp,
451+ *ldb,
452+ bet,
453+ // c,
454+ C_dsp,
455+ *ldc,
456+ cluster_id);
457+ memcpy (c, C_dsp, c_elems * sizeof (std::complex <double >));
458+
459+ // free_ht(A_dsp);
460+ // free_ht(B_dsp);
461+ free_ht (C_dsp);
462+ free_ht (alp);
463+ free_ht (bet);
464+ }
465+
406466void cgemm_mth_ (const char * transa,
407467 const char * transb,
408468 const int * m,
@@ -443,6 +503,64 @@ void cgemm_mth_(const char* transa,
443503 free_ht (bet);
444504} // cgemm that needn't malloc_ht or free_ht
445505
506+ void cgemm_pack_mth_ (const char * transa,
507+ const char * transb,
508+ const int * m,
509+ const int * n,
510+ const int * k,
511+ const std::complex <float >* alpha,
512+ const std::complex <float >* a,
513+ const int * lda,
514+ const std::complex <float >* b,
515+ const int * ldb,
516+ const std::complex <float >* beta,
517+ std::complex <float >* c,
518+ const int * ldc,
519+ int cluster_id)
520+ {
521+ const bool transa_not = (transa[0 ] == ' N' || transa[0 ] == ' n' );
522+ const bool transb_not = (transb[0 ] == ' N' || transb[0 ] == ' n' );
523+ const size_t a_elems = static_cast <size_t >(*lda) * (transa_not ? static_cast <size_t >(*k) : static_cast <size_t >(*m));
524+ const size_t b_elems = static_cast <size_t >(*ldb) * (transb_not ? static_cast <size_t >(*n) : static_cast <size_t >(*k));
525+ const size_t c_elems = static_cast <size_t >(*ldc) * static_cast <size_t >(*n);
526+
527+ std::complex <float >* A_dsp = static_cast <std::complex <float >*>(malloc_ht (a_elems * sizeof (std::complex <float >), cluster_id));
528+ std::complex <float >* B_dsp = static_cast <std::complex <float >*>(malloc_ht (b_elems * sizeof (std::complex <float >), cluster_id));
529+ std::complex <float >* C_dsp = static_cast <std::complex <float >*>(malloc_ht (c_elems * sizeof (std::complex <float >), cluster_id));
530+ std::complex <float >* alp = static_cast <std::complex <float >*>(malloc_ht (sizeof (std::complex <float >), cluster_id));
531+ std::complex <float >* bet = static_cast <std::complex <float >*>(malloc_ht (sizeof (std::complex <float >), cluster_id));
532+
533+ memcpy (A_dsp, a, a_elems * sizeof (std::complex <float >));
534+ memcpy (B_dsp, b, b_elems * sizeof (std::complex <float >));
535+ memcpy (C_dsp, c, c_elems * sizeof (std::complex <float >));
536+ *alp = *alpha;
537+ *bet = *beta;
538+
539+ mt_hthread_cgemm (CBLAS_ORDER::CblasColMajor,
540+ convertBLASTranspose (transa),
541+ convertBLASTranspose (transb),
542+ *m,
543+ *n,
544+ *k,
545+ (const void *)alp,
546+ (const void *)A_dsp,
547+ *lda,
548+ (const void *)B_dsp,
549+ *ldb,
550+ (const void *)bet,
551+ (void *)C_dsp,
552+ *ldc,
553+ cluster_id);
554+
555+ memcpy (c, C_dsp, c_elems * sizeof (std::complex <float >));
556+
557+ free_ht (A_dsp);
558+ free_ht (B_dsp);
559+ free_ht (C_dsp);
560+ free_ht (alp);
561+ free_ht (bet);
562+ }
563+
446564void sgemv_mth_ (const char * transa,
447565 const int * m,
448566 const int * n,
@@ -570,4 +688,4 @@ void cgemv_mth_(const char* transa,
570688 free_ht (alp);
571689 free_ht (bet);
572690}
573- } // namespace mtfunc
691+ } // namespace mtfunc
0 commit comments