diff --git a/kernels/portable/cpu/op_div.cpp b/kernels/portable/cpu/op_div.cpp
index 9f33907b998..7d951a34372 100644
--- a/kernels/portable/cpu/op_div.cpp
+++ b/kernels/portable/cpu/op_div.cpp
@@ -6,72 +6,18 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/dtype_util.h>
-#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
-#include <executorch/kernels/portable/cpu/util/math_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/kernels/portable/cpu/op_div_impl.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-namespace {
-
-ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
-  if (isFloatingType(a_type) && isFloatingType(b_type)) {
-    return promoteTypes(a_type, b_type);
-  } else if (isFloatingType(a_type)) {
-    return a_type;
-  } else if (isFloatingType(b_type)) {
-    return b_type;
-  }
-  return ScalarType::Float;
-}
-
-} // namespace
-
 Tensor& div_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.out";
-
-  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          return val_a / val_b;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::FLOATHBF16);
-  });
-
-  return out;
+  return div_out_impl(ctx, a, b, out);
 }
 
 Tensor& div_out_mode(
@@ -80,84 +26,7 @@ Tensor& div_out_mode(
     const Tensor& b,
     exec_aten::optional<exec_aten::string_view> mode,
     Tensor& out) {
-  if (!mode.has_value()) {
-    return div_out(ctx, a, b, out);
-  }
-
-  auto mode_val = mode.value();
-
-  // Check mode
-  ET_KERNEL_CHECK(
-      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
-
-  // Common Dtype
-  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       common_type != ScalarType::Bool),
-      InvalidArgument,
-      out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.out_mode";
-
-  const bool mode_is_trunc = mode_val == "trunc";
-  bool div_by_zero_error = false;
-
-  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [mode_is_trunc, &div_by_zero_error](
-            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
-          if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
-            if (val_b == 0) {
-              div_by_zero_error = true;
-              return static_cast<CTYPE_COMPUTE>(0);
-            }
-          }
-          CTYPE_COMPUTE value = val_a / val_b;
-          if (mode_is_trunc) {
-            value = std::trunc(value);
-          } else {
-            // We established above that the mode is either trunc or floor, so
-            // it must be floor.
-            value = utils::floor_divide(val_a, val_b);
-          }
-          return value;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
-  });
-
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !div_by_zero_error,
-      InvalidArgument,
-      out,
-      "Div mode operation encountered integer division by zero");
-
-  return out;
+  return div_out_mode_impl(ctx, a, b, mode, out);
 }
 
 Tensor& div_scalar_out(
@@ -165,39 +34,7 @@ Tensor& div_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      isFloatingType(a.scalar_type()) ? a.scalar_type() : ScalarType::Float;
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.Scalar_out";
-
-  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::SAME_AS_COMMON);
-  });
-
-  return out;
+  return div_scalar_out_impl(ctx, a, b, out);
 }
 
 Tensor& div_scalar_mode_out(
@@ -206,72 +43,7 @@ Tensor& div_scalar_mode_out(
     const Scalar& b,
     exec_aten::optional<exec_aten::string_view> mode,
     Tensor& out) {
-  if (!mode.has_value()) {
-    return div_scalar_out(ctx, a, b, out);
-  }
-
-  auto mode_val = mode.value();
-
-  // Check mode
-  ET_KERNEL_CHECK(
-      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
-
-  // Common Dtype
-  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
-
-  // Check Common Dtype
-  ET_KERNEL_CHECK(
-      ctx,
-      (canCast(common_type, out.scalar_type()) &&
-       common_type != ScalarType::Bool),
-      InvalidArgument,
-      out);
-
-  // Check for intergral division by zero
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !(executorch::runtime::isIntegralType(common_type, true) &&
-        utils::scalar_to<double>(b) == 0),
-      InvalidArgument,
-      out,
-      "Div mode operation encountered integer division by zero");
-
-  // Check Dim Order
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  // Resize
-  ET_KERNEL_CHECK(
-      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
-
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
-  const bool mode_is_trunc = mode_val == "trunc";
-
-  // @lint-ignore CLANGTIDY facebook-hte-CArray
-  static constexpr const char op_name[] = "div.Scalar_mode_out";
-
-  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
-    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
-        [val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) {
-          CTYPE_COMPUTE value = val_a / val_b;
-          if (mode_is_trunc) {
-            value = std::trunc(value);
-          } else {
-            value = utils::floor_divide(val_a, val_b);
-          }
-          return value;
-        },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out,
-        utils::SupportedTensorDtypes::REALHBF16);
-  });
-
-  return out;
+  return div_scalar_mode_out_impl(ctx, a, b, mode, out);
 }
 
 } // namespace native
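After this change, the four registered `div` kernels in op_div.cpp are one-line wrappers around shared `_impl` functions, so other builds can link the same logic directly. A rough sketch of the consumer pattern this enables — a hypothetical optimized kernel that specializes a fast path and falls back to the shared portable implementation (everything here except `div_out_impl` is illustrative, not part of this diff):

```cpp
#include <executorch/kernels/portable/cpu/op_div_impl.h>

namespace torch {
namespace executor {
namespace native {

// Hypothetical optimized entry point. A real optimized kernel would check
// dtypes, contiguity, and broadcast shapes before taking a vectorized path;
// the point here is only that the generic path is now a plain function call.
Tensor& opt_div_out_sketch(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // if (can_use_fast_path(a, b, out)) { ... vectorized loop ...; return out; }
  return div_out_impl(ctx, a, b, out); // shared portable implementation
}

} // namespace native
} // namespace executor
} // namespace torch
```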
diff --git a/kernels/portable/cpu/op_div_impl.cpp b/kernels/portable/cpu/op_div_impl.cpp
new file mode 100644
index 00000000000..7e92ca0b24e
--- /dev/null
+++ b/kernels/portable/cpu/op_div_impl.cpp
@@ -0,0 +1,281 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/kernels/portable/cpu/util/math_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+#include <executorch/kernels/portable/cpu/op_div_impl.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+namespace {
+
+ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
+  if (isFloatingType(a_type) && isFloatingType(b_type)) {
+    return promoteTypes(a_type, b_type);
+  } else if (isFloatingType(a_type)) {
+    return a_type;
+  } else if (isFloatingType(b_type)) {
+    return b_type;
+  }
+  return ScalarType::Float;
+}
+
+} // namespace
+
+Tensor& div_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out) {
+  // Common Dtype
+  ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "div.out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return val_a / val_b;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::FLOATHBF16);
+  });
+
+  return out;
+}
+
+Tensor& div_out_mode_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    exec_aten::optional<exec_aten::string_view> mode,
+    Tensor& out) {
+  if (!mode.has_value()) {
+    return div_out_impl(ctx, a, b, out);
+  }
+
+  auto mode_val = mode.value();
+
+  // Check mode
+  ET_KERNEL_CHECK(
+      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
+
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
+      ctx,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
+      InvalidArgument,
+      out);
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "div.out_mode";
+
+  const bool mode_is_trunc = mode_val == "trunc";
+  bool div_by_zero_error = false;
+
+  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [mode_is_trunc, &div_by_zero_error](
+            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
+            if (val_b == 0) {
+              div_by_zero_error = true;
+              return static_cast<CTYPE_COMPUTE>(0);
+            }
+          }
+          CTYPE_COMPUTE value = val_a / val_b;
+          if (mode_is_trunc) {
+            value = std::trunc(value);
+          } else {
+            // We established above that the mode is either trunc or floor, so
+            // it must be floor.
+            value = utils::floor_divide(val_a, val_b);
+          }
+          return value;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
+  });
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      !div_by_zero_error,
+      InvalidArgument,
+      out,
+      "Div mode operation encountered integer division by zero");
+
+  return out;
+}
+
+Tensor& div_scalar_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    Tensor& out) {
+  // Common Dtype
+  ScalarType common_type =
+      isFloatingType(a.scalar_type()) ? a.scalar_type() : ScalarType::Float;
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "div.Scalar_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
+  });
+
+  return out;
+}
+
+Tensor& div_scalar_mode_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    exec_aten::optional<exec_aten::string_view> mode,
+    Tensor& out) {
+  if (!mode.has_value()) {
+    return div_scalar_out_impl(ctx, a, b, out);
+  }
+
+  auto mode_val = mode.value();
+
+  // Check mode
+  ET_KERNEL_CHECK(
+      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
+
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
+      ctx,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
+      InvalidArgument,
+      out);
+
+  // Check for integral division by zero
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      !(executorch::runtime::isIntegralType(common_type, true) &&
+        utils::scalar_to<double>(b) == 0),
+      InvalidArgument,
+      out,
+      "Div mode operation encountered integer division by zero");
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  const bool mode_is_trunc = mode_val == "trunc";
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "div.Scalar_mode_out";
+
+  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) {
+          CTYPE_COMPUTE value = val_a / val_b;
+          if (mode_is_trunc) {
+            value = std::trunc(value);
+          } else {
+            value = utils::floor_divide(val_a, val_b);
+          }
+          return value;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
+  });
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
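The `mode` branch in `div_out_mode_impl` is the subtle part of the code above: `trunc` rounds the quotient toward zero, `floor` rounds toward negative infinity, and the two disagree exactly when the operands have opposite signs and the division is inexact. A self-contained sketch of that distinction using only the C++ standard library (not the `utils::floor_divide` helper the kernel uses):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const double q_neg = -7.0 / 2.0; // -3.5: operands have opposite signs
  const double q_pos = 7.0 / 2.0;  //  3.5: operands have the same sign

  // "trunc" mode rounds toward zero; "floor" mode rounds toward -infinity.
  std::printf("trunc(-7/2) = %g\n", std::trunc(q_neg)); // -3
  std::printf("floor(-7/2) = %g\n", std::floor(q_neg)); // -4

  // With same-sign operands the two modes agree.
  std::printf("trunc(7/2) = %g\n", std::trunc(q_pos)); // 3
  std::printf("floor(7/2) = %g\n", std::floor(q_pos)); // 3
  return 0;
}
```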
diff --git a/kernels/portable/cpu/op_div_impl.h b/kernels/portable/cpu/op_div_impl.h
new file mode 100644
index 00000000000..800886ae9be
--- /dev/null
+++ b/kernels/portable/cpu/op_div_impl.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+Tensor& div_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out);
+
+Tensor& div_out_mode_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    exec_aten::optional<exec_aten::string_view> mode,
+    Tensor& out);
+
+Tensor& div_scalar_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    Tensor& out);
+
+Tensor& div_scalar_mode_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    exec_aten::optional<exec_aten::string_view> mode,
+    Tensor& out);
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/targets.bzl b/kernels/portable/cpu/targets.bzl
index 20434459489..ab0e23fcfbb 100644
--- a/kernels/portable/cpu/targets.bzl
+++ b/kernels/portable/cpu/targets.bzl
@@ -87,3 +87,42 @@ def define_common_targets():
         srcs = native.glob(["*.h"]),
         visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
     )
+
+    runtime.cxx_library(
+        name = "op_div_impl",
+        srcs = ["op_div_impl.cpp"],
+        exported_headers = ["op_div_impl.h"],
+        visibility = [
+            "//executorch/kernels/portable/cpu/...",
+            "//executorch/kernels/optimized/cpu/...",
+            "//executorch/kernels/portable/test/...",
+            "@EXECUTORCH_CLIENTS",
+        ],
+        exported_deps = [
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
+            "//executorch/kernels/portable/cpu/util:math_util",
+            "//executorch/kernels/portable/cpu:scalar_utils",
+            "//executorch/runtime/core/exec_aten:lib",
+            "//executorch/runtime/core/exec_aten/util:scalar_type_util",
+            "//executorch/runtime/core/exec_aten/util:tensor_util",
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+    )
+
+    # The targets below do not participate in dtype selective build: they are
+    # refactored so that optimized op implementations can reuse them, and
+    # selective build is not enabled for optimized ops. To make these ops
+    # selective-build-aware, their sources would have to be copied by the
+    # selective build flow; a file such as op_div_impl.cpp would then be
+    # compiled twice, once for selective build and once for the optimized
+    # library, and linking the two together would produce two copies of
+    # op_div_impl.o and hence duplicate symbols.
+    runtime.cxx_library(
+        name = "all_impl_deps",
+        deps = [
+            "//executorch/kernels/portable/cpu:op_div_impl",
+        ],
+        visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
+    )
diff --git a/shim/xplat/executorch/codegen/codegen.bzl b/shim/xplat/executorch/codegen/codegen.bzl
index d989e33591f..50b495ca68e 100644
--- a/shim/xplat/executorch/codegen/codegen.bzl
+++ b/shim/xplat/executorch/codegen/codegen.bzl
@@ -395,7 +395,7 @@ def build_portable_lib(name, oplist_header_name, feature = None):
         srcs = portable_source_files,
         exported_headers = portable_header_files,
         exported_preprocessor_flags = ["-DEXECUTORCH_SELECTIVE_BUILD_DTYPE"],
-        deps = ["//executorch/kernels/portable/cpu/pattern:all_deps", "//executorch/kernels/portable/cpu/util:all_deps"],
+        deps = ["//executorch/kernels/portable/cpu/pattern:all_deps", "//executorch/kernels/portable/cpu/util:all_deps", "//executorch/kernels/portable/cpu:all_impl_deps"],
         # header_namespace is only available in xplat. See https://fburl.com/code/we2gvopk
         header_namespace = "executorch/kernels/portable/cpu",
         compiler_flags = ["-Wno-missing-prototypes"] +
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index f63932d4840..962e1fcfc11 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -475,11 +475,7 @@ ATEN_OPS = (
     op_target(
         name = "op_div",
         deps = [
-            "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:dtype_util",
-            "//executorch/kernels/portable/cpu/util:elementwise_util",
-            "//executorch/kernels/portable/cpu/util:math_util",
-            ":scalar_utils",
+            "//executorch/kernels/portable/cpu:op_div_impl",
         ],
     ),
     op_target(
@@ -1275,4 +1271,4 @@ def portable_source_list():
 
 def portable_header_list():
     """All the header file names from //executorch/kernels/portable/cpu/"""
-    return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h"]
+    return ["selective_build.h", "scalar_utils.h", "math_constants.h", "vec_ops.h", "op_div_impl.h"]
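Since `op_div_impl` is now a standalone target visible to `//executorch/kernels/portable/test/...`, a test can exercise the shared implementation directly instead of going through operator registration. A minimal sketch of such a call, assuming the ExecuTorch testing helper `torch::executor::testing::TensorFactory` (the shapes and values are illustrative):

```cpp
#include <executorch/kernels/portable/cpu/op_div_impl.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>

using namespace torch::executor;

int main() {
  testing::TensorFactory<exec_aten::ScalarType::Float> tf;

  Tensor a = tf.make({2, 2}, {1.0, 2.0, 3.0, 4.0});
  Tensor b = tf.make({2, 2}, {2.0, 2.0, 2.0, 2.0});
  Tensor out = tf.zeros({2, 2});

  // Invoke the shared implementation directly, bypassing registration.
  KernelRuntimeContext ctx;
  native::div_out_impl(ctx, a, b, out); // expected: {0.5, 1.0, 1.5, 2.0}
  return 0;
}
```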