diff --git a/kernels/optimized/cpu/op_bmm.cpp b/kernels/optimized/cpu/op_bmm.cpp
index 5e7fa1dd839..11697f9b0de 100644
--- a/kernels/optimized/cpu/op_bmm.cpp
+++ b/kernels/optimized/cpu/op_bmm.cpp
@@ -6,9 +6,9 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/runtime/kernel/kernel_includes.h>
-
 #include <executorch/kernels/optimized/blas/CPUBlas.h>
+#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
 
 // Performs a batch matrix-matrix product of matrices stored in input and mat2.
 
@@ -136,33 +136,32 @@ Error resize_out_tensor(const Tensor& self, const Tensor& mat2, Tensor& out) {
 
 // bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
 Tensor& opt_bmm_out(
-    KernelRuntimeContext& context,
+    KernelRuntimeContext& ctx,
     const Tensor& self,
     const Tensor& mat2,
     Tensor& out) {
-  (void)context;
+  (void)ctx;
 
   ET_KERNEL_CHECK(
-      context,
+      ctx,
       resize_out_tensor(self, mat2, out) == Error::Ok,
       InvalidArgument,
       out);
   ET_KERNEL_CHECK(
-      context, check_bmm_out_args(self, mat2, out), InvalidArgument, out);
-
-#define BMM_TENSOR(ctype, dtype)          \
-  case ScalarType::dtype:                 \
-    bmm_kernel<ctype>(self, mat2, out);   \
-    break;
-
-  auto scalar_type = self.scalar_type();
-  switch (scalar_type) {
-    ET_FORALL_REAL_TYPES_AND(Half, BMM_TENSOR)
-    default:
-      ET_CHECK_MSG(
-          false, "Unhandled dtype %" PRId8, static_cast<int8_t>(scalar_type));
+      ctx, check_bmm_out_args(self, mat2, out), InvalidArgument, out);
+
+  constexpr auto name = "bmm.out";
+  auto self_type = self.scalar_type();
+
+  if (executorch::runtime::isComplexType(self_type)) {
+    ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() {
+      internal::bmm_out_impl<CTYPE>(self, mat2, out);
+    });
+  } else {
+    ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() {
+      bmm_kernel<CTYPE>(self, mat2, out);
+    });
   }
-#undef BMM_TENSOR
 
   return out;
 }
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index c9da2584f08..bf24e4de49c 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -15,6 +15,7 @@ _OPTIMIZED_ATEN_OPS = (
         name = "op_bmm",
         deps = [
             "//executorch/kernels/optimized:libblas",
+            "//executorch/kernels/portable/cpu/util:matmul_ops_util",
         ],
     ),
     op_target(
diff --git a/kernels/portable/cpu/op_bmm.cpp b/kernels/portable/cpu/op_bmm.cpp
index b9f9d4f2c94..a887cd3c926 100644
--- a/kernels/portable/cpu/op_bmm.cpp
+++ b/kernels/portable/cpu/op_bmm.cpp
@@ -7,7 +7,6 @@
  */
 
 #include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
-#include <executorch/kernels/portable/cpu/vec_ops.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -37,26 +36,19 @@ Tensor& bmm_out(
       InvalidArgument,
       out);
 
-  ET_SWITCH_REAL_TYPES_AND(
-      Half, in.scalar_type(), ctx, "bmm.out", CTYPE, [&]() {
-        const CTYPE* in_data = in.const_data_ptr<CTYPE>();
-        const CTYPE* mat2_data = mat2.const_data_ptr<CTYPE>();
-        CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+  constexpr auto name = "bmm.out";
 
-        int64_t batch_size = in.size(0);
-        int64_t m = in.size(1);
-        int64_t n = in.size(2);
-        int64_t p = mat2.size(2);
+  auto in_type = in.scalar_type();
 
-        for (int i = 0; i < batch_size; ++i) {
-          const CTYPE* in_data_offset = in_data + i * m * n;
-          const CTYPE* mat2_data_offset = mat2_data + i * n * p;
-          CTYPE* out_data_offset = out_data + i * m * p;
-
-          vec_matmul<CTYPE>(
-              out_data_offset, in_data_offset, mat2_data_offset, m, n, p);
-        }
-      });
+  if (executorch::runtime::isComplexType(in_type)) {
+    ET_SWITCH_COMPLEXH_TYPES(in_type, ctx, name, CTYPE, [&]() {
+      internal::bmm_out_impl<CTYPE>(in, mat2, out);
+    });
+  } else {
+    ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
+      internal::bmm_out_impl<CTYPE>(in, mat2, out);
+    });
+  }
 
   return out;
 }
diff --git a/kernels/portable/cpu/util/matmul_ops_util.h b/kernels/portable/cpu/util/matmul_ops_util.h
index d2991868e95..2d2799eaa59 100644
--- a/kernels/portable/cpu/util/matmul_ops_util.h
+++ b/kernels/portable/cpu/util/matmul_ops_util.h
@@ -45,5 +45,36 @@ void get_linear_out_target_size(
     Tensor::SizesType* out_sizes,
     size_t* out_ndim);
 
+namespace internal {
+
+template <typename CTYPE>
+void bmm_out_impl(const Tensor& in, const Tensor& mat2, Tensor& out) {
+  const CTYPE* in_data = in.const_data_ptr<CTYPE>();
+  const CTYPE* mat2_data = mat2.const_data_ptr<CTYPE>();
+  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+
+  int64_t batch_size = in.size(0);
+  int64_t m = in.size(1);
+  int64_t n = in.size(2);
+  int64_t p = mat2.size(2);
+
+  for (int b = 0; b < batch_size; ++b) {
+    const CTYPE* in_data_offset = in_data + b * m * n;
+    const CTYPE* mat2_data_offset = mat2_data + b * n * p;
+    CTYPE* out_data_offset = out_data + b * m * p;
+
+    for (const auto i : c10::irange(m)) {
+      for (const auto j : c10::irange(p)) {
+        CTYPE sum = static_cast<CTYPE>(0.0);
+        for (const auto k : c10::irange(n)) {
+          sum += in_data_offset[i * n + k] * mat2_data_offset[k * p + j];
+        }
+        out_data_offset[i * p + j] = sum;
+      }
+    }
+  }
+}
+
+} // namespace internal
 } // namespace executor
 } // namespace torch
diff --git a/kernels/test/op_bmm_test.cpp b/kernels/test/op_bmm_test.cpp
index 88671467f46..70a5f37946d 100644
--- a/kernels/test/op_bmm_test.cpp
+++ b/kernels/test/op_bmm_test.cpp
@@ -43,6 +43,61 @@ class OpBmmOutTest : public OperatorTest {
 
     EXPECT_TENSOR_EQ(out, expected);
   }
+
+  template <typename CTYPE, ScalarType DTYPE>
+  void test_complex_dtype() {
+    TensorFactory<DTYPE> tf;
+    Tensor x = tf.make(
+        {2, 2, 3},
+        {CTYPE(1, 1),
+         CTYPE(2, 2),
+         CTYPE(3, 3),
+         CTYPE(4, 4),
+         CTYPE(5, 5),
+         CTYPE(6, 6),
+         CTYPE(7, 7),
+         CTYPE(8, 8),
+         CTYPE(9, 9),
+         CTYPE(10, 10),
+         CTYPE(11, 11),
+         CTYPE(12, 12)});
+    Tensor y = tf.make(
+        {2, 3, 2},
+        {CTYPE(2, 1),
+         CTYPE(4, 2),
+         CTYPE(6, 3),
+         CTYPE(8, 4),
+         CTYPE(10, 5),
+         CTYPE(12, 6),
+         CTYPE(14, 7),
+         CTYPE(16, 8),
+         CTYPE(18, 9),
+         CTYPE(20, 10),
+         CTYPE(22, 11),
+         CTYPE(24, 12)});
+    Tensor out = tf.make(
+        {2, 2, 2},
+        {CTYPE(0, 0),
+         CTYPE(0, 0),
+         CTYPE(0, 0),
+         CTYPE(0, 0),
+         CTYPE(0, 0),
+         CTYPE(0, 0),
+         CTYPE(0, 0),
+         CTYPE(0, 0)});
+    Tensor expected = tf.make(
+        {2, 2, 2},
+        {CTYPE(22, 66),
+         CTYPE(28, 84),
+         CTYPE(49, 147),
+         CTYPE(64, 192),
+         CTYPE(220, 660),
+         CTYPE(244, 732),
+         CTYPE(301, 903),
+         CTYPE(334, 1002)});
+    op_bmm_out(x, y, out);
+    EXPECT_TENSOR_CLOSE(out, expected);
+  }
 };
 
 TEST_F(OpBmmOutTest, OutputDim) {
@@ -132,7 +187,7 @@ TEST_F(OpBmmOutTest, OutputDimFloat) {
 
 /// A generic smoke test that works for any dtype that supports ones() and
 /// zeros().
-TEST_F(OpBmmOutTest, AllDtypesSupported) {
+TEST_F(OpBmmOutTest, AllRealDtypesSupported) {
 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
   ET_FORALL_REAL_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
@@ -141,6 +196,16 @@
   // for those types.
 }
 
+TEST_F(OpBmmOutTest, AllComplexDtypesSupported) {
+#define TEST_ENTRY(ctype, dtype) test_complex_dtype<ctype, ScalarType::dtype>();
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    ET_FORALL_COMPLEX_TYPES(TEST_ENTRY);
+  } else {
+    ET_FORALL_COMPLEXH_TYPES(TEST_ENTRY);
+  }
+#undef TEST_ENTRY
+}
+
 TEST_F(OpBmmOutTest, EmptyInputWithEmptyOutTensorPasses) {
   TensorFactory<ScalarType::Float> tf;
 
diff --git a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
index a1ffdc1eed3..06f9a452935 100644
--- a/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -372,7 +372,6 @@ ATEN_OPS = (
         name = "op_bmm",
         deps = [
             "//executorch/kernels/portable/cpu/util:matmul_ops_util",
-            ":vec_ops",
         ],
     ),
     op_target(
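Note: the expected values in test_complex_dtype follow from ordinary complex multiplication applied by the naive triple loop that internal::bmm_out_impl now runs for complex dtypes. The standalone sketch below is not part of the patch: std::complex<float> stands in for the ExecuTorch complex ctype, and bmm_reference is a hypothetical name mirroring the loop structure of bmm_out_impl. It reproduces batch 0 of the test and checks the expected outputs.

// Reference check for the complex bmm test data (plain C++, no ExecuTorch deps).
#include <array>
#include <cassert>
#include <complex>
#include <cstdint>

template <typename CTYPE>
void bmm_reference(
    const CTYPE* in,   // [batch, m, n], contiguous row-major
    const CTYPE* mat2, // [batch, n, p]
    CTYPE* out,        // [batch, m, p]
    int64_t batch_size, int64_t m, int64_t n, int64_t p) {
  for (int64_t b = 0; b < batch_size; ++b) {
    const CTYPE* in_b = in + b * m * n;
    const CTYPE* mat2_b = mat2 + b * n * p;
    CTYPE* out_b = out + b * m * p;
    for (int64_t i = 0; i < m; ++i) {
      for (int64_t j = 0; j < p; ++j) {
        CTYPE sum{}; // zero-initialized accumulator, as in bmm_out_impl
        for (int64_t k = 0; k < n; ++k) {
          sum += in_b[i * n + k] * mat2_b[k * p + j];
        }
        out_b[i * p + j] = sum;
      }
    }
  }
}

int main() {
  using C = std::complex<float>;
  // Batch 0 of x (2x3) and y (3x2) from test_complex_dtype above.
  std::array<C, 6> x = {C(1, 1), C(2, 2), C(3, 3), C(4, 4), C(5, 5), C(6, 6)};
  std::array<C, 6> y = {C(2, 1), C(4, 2), C(6, 3), C(8, 4), C(10, 5), C(12, 6)};
  std::array<C, 4> out{};
  bmm_reference(x.data(), y.data(), out.data(), /*batch_size=*/1, 2, 3, 2);
  // First entry: (1+1i)(2+1i) + (2+2i)(6+3i) + (3+3i)(10+5i)
  //            = (1+3i) + (6+18i) + (15+45i) = 22+66i.
  assert(out[0] == C(22, 66));
  assert(out[1] == C(28, 84));
  assert(out[2] == C(49, 147));
  assert(out[3] == C(64, 192));
  return 0;
}

Because the operands are small integers, every product and sum is exact in single-precision complex arithmetic, which is why the test can state exact expected tensors even though it compares with EXPECT_TENSOR_CLOSE.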