diff --git a/kernels/optimized/blas/CPUBlas.cpp b/kernels/optimized/blas/CPUBlas.cpp index b948fb35488..51a4f1ca26b 100644 --- a/kernels/optimized/blas/CPUBlas.cpp +++ b/kernels/optimized/blas/CPUBlas.cpp @@ -17,6 +17,8 @@ // clang-format off extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc); extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc); +extern "C" void cgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc); +extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc); // clang-format on #endif // ET_BUILD_FOR_APPLE #endif // ET_BUILD_WITH_BLAS @@ -25,6 +27,7 @@ namespace executorch { namespace cpublas { using executorch::aten::BFloat16; +using executorch::aten::complex; using executorch::aten::Half; #ifdef ET_BUILD_WITH_BLAS @@ -197,5 +200,100 @@ void gemm( } // clang-format on +// clang-format off +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const complex alpha, + const complex *a, int64_t lda, + const complex *b, int64_t ldb, + const complex beta, + complex *c, int64_t ldc) { + normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); +#if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE) + complex alpha_ = alpha, beta_ = beta; + int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; + char transa_ = to_blas(transa), transb_ = to_blas(transb); + zgemm_( + &transa_, &transb_, + &m_, &n_, &k_, + &alpha_, + a, &lda_, + b, &ldb_, + &beta_, + c, &ldc_); +#else + using acc_type = utils::compute_dtype>; + gemm_impl( + transa, transb, + m, n, k, + static_cast(alpha), + a, lda, + b, ldb, + static_cast(beta), + c, ldc); +#endif +} +// clang-format on + +// clang-format off +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const complex alpha, + const complex *a, int64_t lda, + const complex *b, int64_t ldb, + const complex beta, + complex *c, int64_t ldc) { + normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); + #if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE) + complex alpha_ = alpha, beta_ = beta; + int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc; + char transa_ = to_blas(transa), transb_ = to_blas(transb); + cgemm_( + &transa_, &transb_, + &m_, &n_, &k_, + &alpha_, + a, &lda_, + b, &ldb_, + &beta_, + c, &ldc_); +#else + using acc_type = utils::compute_dtype>; + gemm_impl( + transa, transb, + m, n, k, + static_cast(alpha), + a, lda, + b, ldb, + static_cast(beta), + c, ldc); +#endif +} +// clang-format on + +// clang-format off +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const complex alpha, + const complex *a, int64_t lda, + const complex *b, int64_t ldb, + const complex beta, + complex *c, int64_t ldc) { + normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); + + using acc_type = utils::compute_dtype>; + gemm_impl( + transa, transb, + m, n, k, + static_cast(alpha), + a, lda, + b, ldb, + static_cast(beta), + c, ldc); +} +// clang-format on + } // namespace cpublas } // namespace executorch diff --git a/kernels/optimized/blas/CPUBlas.h b/kernels/optimized/blas/CPUBlas.h index d8517255f6c..28bf68ad750 100644 --- a/kernels/optimized/blas/CPUBlas.h +++ b/kernels/optimized/blas/CPUBlas.h @@ -111,6 +111,33 @@ void gemm( const executorch::aten::BFloat16 *b, int64_t ldb, const executorch::aten::BFloat16 beta, executorch::aten::BFloat16 *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const executorch::aten::complex alpha, + const executorch::aten::complex *a, int64_t lda, + const executorch::aten::complex *b, int64_t ldb, + const executorch::aten::complex beta, + executorch::aten::complex *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const executorch::aten::complex alpha, + const executorch::aten::complex *a, int64_t lda, + const executorch::aten::complex *b, int64_t ldb, + const executorch::aten::complex beta, + executorch::aten::complex *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const executorch::aten::complex alpha, + const executorch::aten::complex *a, int64_t lda, + const executorch::aten::complex *b, int64_t ldb, + const executorch::aten::complex beta, + executorch::aten::complex *c, int64_t ldc); // clang-format on // clang-format off diff --git a/kernels/optimized/cpu/op_bmm.cpp b/kernels/optimized/cpu/op_bmm.cpp index 11697f9b0de..51e86d54e60 100644 --- a/kernels/optimized/cpu/op_bmm.cpp +++ b/kernels/optimized/cpu/op_bmm.cpp @@ -155,7 +155,7 @@ Tensor& opt_bmm_out( if (executorch::runtime::isComplexType(self_type)) { ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() { - internal::bmm_out_impl(self, mat2, out); + bmm_kernel(self, mat2, out); }); } else { ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() { diff --git a/kernels/optimized/test/libblas_test.cpp b/kernels/optimized/test/libblas_test.cpp index cb4d64d20c6..3c944214417 100644 --- a/kernels/optimized/test/libblas_test.cpp +++ b/kernels/optimized/test/libblas_test.cpp @@ -13,13 +13,17 @@ #include -#define TEST_FORALL_SUPPORTED_CTYPES(_, N) \ - _(); \ - _(); \ - _(); \ - _(); \ - _(); \ - _(); +#define TEST_FORALL_SUPPORTED_CTYPES(_, N) \ + _(); \ + _(); \ + _(); \ + _(); \ + _(); \ + _(); \ + _(); \ + _, N>(); \ + _, N>(); \ + _, N>(); namespace {