diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp index 1ee73d342ca..313f578dc2f 100644 --- a/kernels/portable/cpu/op_mul.cpp +++ b/kernels/portable/cpu/op_mul.cpp @@ -6,10 +6,7 @@ * LICENSE file in the root directory of this source tree. */ -#include <executorch/kernels/portable/cpu/scalar_utils.h> -#include <executorch/kernels/portable/cpu/util/broadcast_util.h> -#include <executorch/kernels/portable/cpu/util/dtype_util.h> -#include <executorch/kernels/portable/cpu/util/elementwise_util.h> +#include <executorch/kernels/portable/cpu/op_mul_impl.h> namespace torch { namespace executor { @@ -20,52 +17,7 @@ Tensor& mul_out( const Tensor& a, const Tensor& b, Tensor& out) { - // Common Dtype - ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); - - // Check Common Dtype - ET_KERNEL_CHECK( - ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); - - // Check Dim Order - ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); - - // Resize - ET_KERNEL_CHECK( - ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, - InvalidArgument, - out); - - // Compute Dtype - ScalarType compute_type = utils::get_compute_type(common_type); - - // @lint-ignore CLANGTIDY facebook-hte-CArray - static constexpr const char op_name[] = "mul.out"; - - ET_KERNEL_CHECK( - ctx, - (executorch::runtime::isRealType(compute_type) || - compute_type == ScalarType::Bool), - InvalidArgument, - out); - - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>( - [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { - return val_a * val_b; - }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - b, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::REALHBBF16); - }); - - return out; + return mul_out_impl(ctx, a, b, out); } Tensor& mul_scalar_out( @@ -73,38 +25,7 @@ Tensor& mul_scalar_out( const Tensor& a, const Scalar& b, Tensor& out) { - // Common Dtype - ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); - - // Check Common Dtype - ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); - - // Check Dim Order - 
ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); - - // Resize - ET_KERNEL_CHECK( - ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); - - // Compute Dtype - ScalarType compute_type = utils::get_compute_type(common_type); - - // @lint-ignore CLANGTIDY facebook-hte-CArray - static constexpr const char op_name[] = "mul.Scalar_out"; - - ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { - const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b); - utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>( - [val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; }, - ctx, - a, - utils::SupportedTensorDtypes::REALHBBF16, - out, - utils::SupportedTensorDtypes::SAME_AS_COMMON); - }); - - return out; + return mul_scalar_out_impl(ctx, a, b, out); } } // namespace native diff --git a/kernels/portable/cpu/op_mul_impl.cpp b/kernels/portable/cpu/op_mul_impl.cpp new file mode 100644 index 00000000000..d52bfad6ac1 --- /dev/null +++ b/kernels/portable/cpu/op_mul_impl.cpp @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include <executorch/kernels/portable/cpu/scalar_utils.h> +#include <executorch/kernels/portable/cpu/util/broadcast_util.h> +#include <executorch/kernels/portable/cpu/util/dtype_util.h> +#include <executorch/kernels/portable/cpu/util/elementwise_util.h> + +#include <executorch/kernels/portable/cpu/op_mul_impl.h> + +namespace torch { +namespace executor { +namespace native { + +Tensor& mul_out_impl( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + Tensor& out) { + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); + + // Resize + ET_KERNEL_CHECK( + ctx, + resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "mul.out"; + + ET_KERNEL_CHECK( + ctx, + (executorch::runtime::isRealType(compute_type) || + compute_type == ScalarType::Bool), + InvalidArgument, + out); + + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); + }); + + return out; +} + +Tensor& mul_scalar_out_impl( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + Tensor& out) { + // Common Dtype + ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); + + // Check Common Dtype + ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); + + // Resize + ET_KERNEL_CHECK( + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); + + // Compute Dtype + ScalarType compute_type = 
utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "mul.Scalar_out"; + + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b); + utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>( + [val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); + }); + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/op_mul_impl.h b/kernels/portable/cpu/op_mul_impl.h new file mode 100644 index 00000000000..f3d616e639c --- /dev/null +++ b/kernels/portable/cpu/op_mul_impl.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include <executorch/runtime/kernel/kernel_includes.h> + +namespace torch { +namespace executor { +namespace native { + +Tensor& mul_out_impl( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + Tensor& out); + +Tensor& mul_scalar_out_impl( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + Tensor& out); + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl index ab0e23fcfbb..e6cb5c02ff4 100644 --- a/kernels/portable/cpu/targets.bzl +++ b/kernels/portable/cpu/targets.bzl @@ -111,6 +111,26 @@ def define_common_targets(): ], ) + runtime.cxx_library( + name = "op_mul_impl", + srcs = ["op_mul_impl.cpp"], + exported_headers = ["op_mul_impl.h"], + visibility = [ + "//executorch/kernels/portable/cpu/...", + "//executorch/kernels/optimized/cpu/...", + "//executorch/kernels/portable/test/...", + "@EXECUTORCH_CLIENTS", + ], + exported_deps = [ + "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/kernels/portable/cpu/util:dtype_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", + "//executorch/kernels/portable/cpu/util:math_util", + "//executorch/kernels/portable/cpu:scalar_utils", + "//executorch/runtime/kernel:kernel_includes", + ], + ) + # The following will not participate in dtype selective build because # they are refactored such to be used in optimized op implementations as well # and we have not enabled selective build for optimized ops. 
@@ -123,6 +143,7 @@ def define_common_targets(): name = "all_impl_deps", deps = [ "//executorch/kernels/portable/cpu:op_div_impl", + "//executorch/kernels/portable/cpu:op_mul_impl", ], visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"], ) diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index 962e1fcfc11..b1a612a7e74 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -842,10 +842,7 @@ ATEN_OPS = ( op_target( name = "op_mul", deps = [ - "//executorch/kernels/portable/cpu/util:broadcast_util", - "//executorch/kernels/portable/cpu/util:dtype_util", - "//executorch/kernels/portable/cpu/util:elementwise_util", - ":scalar_utils", + "//executorch/kernels/portable/cpu:op_mul_impl", ], ), op_target( @@ -1271,4 +1268,4 @@ def portable_source_list(): def portable_header_list(): """All the header file names from //executorch/kernels/portable/cpu/""" - return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h"] + return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h", "op_mul_impl.h"]