Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions kernels/optimized/blas/CPUBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
// clang-format off
extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc);
extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);
extern "C" void cgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
// clang-format on
#endif // ET_BUILD_FOR_APPLE
#endif // ET_BUILD_WITH_BLAS
Expand All @@ -25,6 +27,7 @@ namespace executorch {
namespace cpublas {

using executorch::aten::BFloat16;
using executorch::aten::complex;
using executorch::aten::Half;

#ifdef ET_BUILD_WITH_BLAS
Expand Down Expand Up @@ -197,5 +200,100 @@ void gemm(
}
// clang-format on

// clang-format off
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const complex<double> alpha,
const complex<double> *a, int64_t lda,
const complex<double> *b, int64_t ldb,
const complex<double> beta,
complex<double> *c, int64_t ldc) {
normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE)
complex<double> alpha_ = alpha, beta_ = beta;
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
char transa_ = to_blas(transa), transb_ = to_blas(transb);
zgemm_(
&transa_, &transb_,
&m_, &n_, &k_,
&alpha_,
a, &lda_,
b, &ldb_,
&beta_,
c, &ldc_);
#else
using acc_type = utils::compute_dtype<complex<double>>;
gemm_impl(
transa, transb,
m, n, k,
static_cast<const acc_type>(alpha),
a, lda,
b, ldb,
static_cast<const acc_type>(beta),
c, ldc);
#endif
}
// clang-format on

// clang-format off
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const complex<float> alpha,
const complex<float> *a, int64_t lda,
const complex<float> *b, int64_t ldb,
const complex<float> beta,
complex<float> *c, int64_t ldc) {
normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
#if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE)
complex<float> alpha_ = alpha, beta_ = beta;
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
char transa_ = to_blas(transa), transb_ = to_blas(transb);
cgemm_(
&transa_, &transb_,
&m_, &n_, &k_,
&alpha_,
a, &lda_,
b, &ldb_,
&beta_,
c, &ldc_);
#else
using acc_type = utils::compute_dtype<complex<float>>;
gemm_impl(
transa, transb,
m, n, k,
static_cast<const acc_type>(alpha),
a, lda,
b, ldb,
static_cast<const acc_type>(beta),
c, ldc);
#endif
}
// clang-format on

// clang-format off
void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const complex<Half> alpha,
const complex<Half> *a, int64_t lda,
const complex<Half> *b, int64_t ldb,
const complex<Half> beta,
complex<Half> *c, int64_t ldc) {
normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);

using acc_type = utils::compute_dtype<complex<Half>>;
gemm_impl(
transa, transb,
m, n, k,
static_cast<const acc_type>(alpha),
a, lda,
b, ldb,
static_cast<const acc_type>(beta),
c, ldc);
}
// clang-format on

} // namespace cpublas
} // namespace executorch
27 changes: 27 additions & 0 deletions kernels/optimized/blas/CPUBlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,33 @@ void gemm(
const executorch::aten::BFloat16 *b, int64_t ldb,
const executorch::aten::BFloat16 beta,
executorch::aten::BFloat16 *c, int64_t ldc);

void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const executorch::aten::complex<double> alpha,
const executorch::aten::complex<double> *a, int64_t lda,
const executorch::aten::complex<double> *b, int64_t ldb,
const executorch::aten::complex<double> beta,
executorch::aten::complex<double> *c, int64_t ldc);

void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const executorch::aten::complex<float> alpha,
const executorch::aten::complex<float> *a, int64_t lda,
const executorch::aten::complex<float> *b, int64_t ldb,
const executorch::aten::complex<float> beta,
executorch::aten::complex<float> *c, int64_t ldc);

void gemm(
TransposeType transa, TransposeType transb,
int64_t m, int64_t n, int64_t k,
const executorch::aten::complex<executorch::aten::Half> alpha,
const executorch::aten::complex<executorch::aten::Half> *a, int64_t lda,
const executorch::aten::complex<executorch::aten::Half> *b, int64_t ldb,
const executorch::aten::complex<executorch::aten::Half> beta,
executorch::aten::complex<executorch::aten::Half> *c, int64_t ldc);
// clang-format on

// clang-format off
Expand Down
2 changes: 1 addition & 1 deletion kernels/optimized/cpu/op_bmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ Tensor& opt_bmm_out(

if (executorch::runtime::isComplexType(self_type)) {
ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() {
internal::bmm_out_impl<CTYPE>(self, mat2, out);
bmm_kernel<CTYPE>(self, mat2, out);
});
} else {
ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() {
Expand Down
18 changes: 11 additions & 7 deletions kernels/optimized/test/libblas_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@

#include <vector>

#define TEST_FORALL_SUPPORTED_CTYPES(_, N) \
_<double, N>(); \
_<float, N>(); \
_<int64_t, N>(); \
_<uint8_t, N>(); \
_<int32_t, N>(); \
_<executorch::aten::BFloat16, N>();
#define TEST_FORALL_SUPPORTED_CTYPES(_, N) \
_<double, N>(); \
_<float, N>(); \
_<int64_t, N>(); \
_<uint8_t, N>(); \
_<int32_t, N>(); \
_<executorch::aten::Half, N>(); \
_<executorch::aten::BFloat16, N>(); \
_<executorch::aten::complex<double>, N>(); \
_<executorch::aten::complex<float>, N>(); \
_<executorch::aten::complex<executorch::aten::Half>, N>();

namespace {

Expand Down
Loading