diff --git a/kernels/optimized/cpu/op_mm.cpp b/kernels/optimized/cpu/op_mm.cpp
new file mode 100644
index 00000000000..9131356aeb6
--- /dev/null
+++ b/kernels/optimized/cpu/op_mm.cpp
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/optimized/blas/CPUBlas.h>
+#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+#include <array>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = exec_aten::Tensor;
+
+Tensor& opt_mm_out(
+    RuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& mat2,
+    Tensor& out) {
+  ET_KERNEL_CHECK(ctx, check_mm_args(in, mat2, out), InvalidArgument, out);
+
+  size_t output_ndim = 0;
+  std::array<exec_aten::SizesType, kTensorDimensionLimit> output_sizes;
+  get_mm_out_target_size(in, mat2, output_sizes.data(), &output_ndim);
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {output_sizes.data(), output_ndim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  if (out.numel() == 0) {
+    return out;
+  }
+  ET_SWITCH_REAL_TYPES_AND2(
+      Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
+        size_t n = in.size(0);
+        size_t k = in.size(1);
+        size_t m = mat2.size(1);
+
+        // gemm expects column-major inputs and produces column-major
+        // output. So, we take advantage of the identity (A @ B).t()
+        // = B.t() @ A.t() here; row-major B is B.t() from gemm's
+        // column-major perspective, etc.
+        executorch::cpublas::gemm(
+            executorch::cpublas::TransposeType::NoTranspose,
+            executorch::cpublas::TransposeType::NoTranspose,
+            m,
+            n,
+            k,
+            static_cast<CTYPE>(1),
+            mat2.const_data_ptr<CTYPE>(),
+            m,
+            in.const_data_ptr<CTYPE>(),
+            k,
+            static_cast<CTYPE>(0),
+            out.mutable_data_ptr<CTYPE>(),
+            m);
+      });
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl
index e7bb2d36bf4..225498aa8d1 100644
--- a/kernels/optimized/cpu/targets.bzl
+++ b/kernels/optimized/cpu/targets.bzl
@@ -52,6 +52,13 @@ _OPTIMIZED_ATEN_OPS = (
             ],
         }),
     ),
+    op_target(
+        name = "op_mm",
+        deps = [
+            "//executorch/kernels/optimized:libblas",
+            "//executorch/kernels/portable/cpu/util:matmul_ops_util",
+        ],
+    ),
     op_target(
         name = "op_mul",
         deps = [
diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml
index 0d445deb3e8..7c2c4d35fd7 100644
--- a/kernels/optimized/optimized.yaml
+++ b/kernels/optimized/optimized.yaml
@@ -52,6 +52,11 @@
   - arg_meta: null
     kernel_name: torch::executor::opt_le_tensor_out
 
+- op: mm.out
+  kernels:
+  - arg_meta: null
+    kernel_name: torch::executor::opt_mm_out
+
- op: mul.out
   kernels:
   - arg_meta: null
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 7ae17c5237a..cd3ca556fe6 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -244,7 +244,7 @@ def define_common_targets():
     _common_op_test("op_mean_test", ["aten", "portable"])
     _common_op_test("op_min_test", ["aten", "portable"])
     _common_op_test("op_minimum_test", ["aten", "portable"])
-    _common_op_test("op_mm_test", ["aten", "portable"])
+    _common_op_test("op_mm_test", ["aten", "portable", "optimized"])
     _common_op_test("op_mul_test", ["aten", "portable", "optimized"])
     _common_op_test("op_narrow_copy_test", ["aten", "portable"])
     _common_op_test("op_native_batch_norm_test", ["aten", "portable"])
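
Note on the gemm trick (reviewer sketch, not part of the diff): cpublas::gemm is column-major, but ExecuTorch tensors are row-major. A row-major matrix reinterpreted as column-major storage is exactly its transpose, so calling gemm with the operands swapped -- mat2 first, then in, with dimensions (m, n, k) -- computes mat2.t() @ in.t() = (in @ mat2).t() in column-major layout, which is in @ mat2 in row-major layout, with no copies or explicit transposes. A minimal standalone C++ sketch of this identity follows; naive_gemm_colmajor and the sample values are hypothetical stand-ins chosen only to mirror the kernel's call shape, not ExecuTorch code.

// Illustrative sketch only: naive_gemm_colmajor stands in for
// executorch::cpublas::gemm with NoTranspose, column-major semantics.
#include <array>
#include <cstdio>

// Naive column-major gemm: C(rows x cols) = A(rows x inner) * B(inner x cols).
void naive_gemm_colmajor(
    int rows, int cols, int inner,
    const float* a, int lda,
    const float* b, int ldb,
    float* c, int ldc) {
  for (int j = 0; j < cols; ++j) {
    for (int i = 0; i < rows; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < inner; ++p) {
        acc += a[i + p * lda] * b[p + j * ldb];
      }
      c[i + j * ldc] = acc;
    }
  }
}

int main() {
  // Mirror the kernel's naming: in is n x k, mat2 is k x m, out is n x m,
  // all row-major.
  constexpr int n = 2, k = 3, m = 2;
  std::array<float, n * k> in = {1, 2, 3, 4, 5, 6};      // row-major n x k
  std::array<float, k * m> mat2 = {7, 8, 9, 10, 11, 12}; // row-major k x m
  std::array<float, n * m> out{};

  // Same argument order as the kernel: dims (m, n, k), mat2 before in.
  // Row-major mat2 read column-major with leading dim m is mat2.t() (m x k);
  // row-major in read column-major with leading dim k is in.t() (k x n).
  // The column-major product mat2.t() @ in.t() = (in @ mat2).t(), and a
  // column-major transpose is exactly the row-major result we want.
  naive_gemm_colmajor(m, n, k, mat2.data(), m, in.data(), k, out.data(), m);

  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < m; ++j) {
      std::printf("%g ", out[i * m + j]);
    }
    std::printf("\n");
  }
  // Prints:
  //   58 64
  //   139 154
  // which is in @ mat2 computed directly.
  return 0;
}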