diff --git a/kernels/portable/cpu/op_floor_divide.cpp b/kernels/portable/cpu/op_floor_divide.cpp
index c6a7902b3d2..85eb612ea1e 100644
--- a/kernels/portable/cpu/op_floor_divide.cpp
+++ b/kernels/portable/cpu/op_floor_divide.cpp
@@ -6,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/kernels/portable/cpu/util/math_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
@@ -17,106 +17,61 @@
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
-
-namespace {
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct FloorDivideInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct FloorDivideInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void
-  run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
-          if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
-            if (val_b == 0) {
-              div_by_zero_error = true;
-              return static_cast<CTYPE_OUT>(0);
-            }
-          }
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = utils::floor_divide(a_casted, b_casted);
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct FloorDivideInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug {};
-
-} // namespace
-
 Tensor& floor_divide_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
       out);
 
-  ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out);
-
+  // Check Dim Order
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
-  ScalarType out_type = out.scalar_type();
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
 
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "floor_divide.out";
 
-  auto div_by_zero_error = false;
+  bool div_by_zero_error = false;
 
-  ET_SWITCH_REAL_TYPES_AND(
-      Bool, a_type, ctx, "floor_divide.out", CTYPE_A, [&]() {
-        ET_SWITCH_REAL_TYPES_AND(
-            Bool, b_type, ctx, "floor_divide.out", CTYPE_B, [&]() {
-              using CTYPE_IN = typename torch::executor::
-                  promote_types<CTYPE_A, CTYPE_B>::type;
-              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-              ET_SWITCH_REAL_TYPES(
-                  out_type, ctx, "floor_divide.out", CTYPE_OUT, [&]() {
-                    FloorDivideInner<
-                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
-                        CTYPE_A,
-                        CTYPE_B,
-                        CTYPE_IN,
-                        CTYPE_OUT>::run(a, b, out, div_by_zero_error);
-                  });
-            });
-      });
+  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [&div_by_zero_error](
+            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
+            if (val_b == 0) {
+              div_by_zero_error = true;
+              return static_cast<CTYPE_COMPUTE>(0);
+            }
+          }
+          return utils::floor_divide(val_a, val_b);
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
+  });
 
   ET_KERNEL_CHECK_MSG(
       ctx,
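Note: this rewrite collapses the old three-level dtype dispatch (nested `ET_SWITCH` over the two input dtypes and the output dtype, with `FloorDivideInner`/`ReportCanCastBug` selecting a valid instantiation) into a single switch over the compute dtype; `utils::apply_bitensor_elementwise_fn` converts each element to and from the compute type at load/store time. A minimal, self-contained sketch of that load/compute/store idea, with illustrative names only (this is not the ExecuTorch API):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical simplification: one load function per input dtype converts
// elements to the compute type; one store function converts back.
using LoadFn = float (*)(const void*, int64_t);
using StoreFn = void (*)(float, void*, int64_t);

template <typename T>
float load_as_float(const void* data, int64_t i) {
  return static_cast<float>(static_cast<const T*>(data)[i]);
}

template <typename T>
void store_from_float(float v, void* data, int64_t i) {
  static_cast<T*>(data)[i] = static_cast<T>(v);
}

// The loop is instantiated once per compute type; input/output dtypes are
// handled by runtime-selected function pointers, not by template expansion.
template <typename Op>
void apply_bitensor_sketch(
    Op op,
    const void* a, LoadFn load_a,
    const void* b, LoadFn load_b,
    void* out, StoreFn store_out,
    int64_t numel) {
  for (int64_t i = 0; i < numel; ++i) {
    store_out(op(load_a(a, i), load_b(b, i)), out, i);
  }
}

int main() {
  int32_t a[4] = {7, -7, 9, -9};
  int16_t b[4] = {2, 2, -4, -4};
  float out[4];
  apply_bitensor_sketch(
      [](float x, float y) { return x / y; },
      a, load_as_float<int32_t>,
      b, load_as_float<int16_t>,
      out, store_from_float<float>,
      4);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
  return 0;
}
```

The payoff is code size: the kernel body is compiled once per compute dtype rather than once per (input, input, output) dtype combination, and the `SupportedTensorDtypes::REALHBBF16`/`REALHBF16` arguments declare which runtime dtypes the load/store paths must handle.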
diff --git a/kernels/portable/cpu/op_fmod.cpp b/kernels/portable/cpu/op_fmod.cpp
index 98db14cc174..1e8cba0f1ae 100644
--- a/kernels/portable/cpu/op_fmod.cpp
+++ b/kernels/portable/cpu/op_fmod.cpp
@@ -9,113 +9,73 @@
 #include <cmath>
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-
-namespace {
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct FmodInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct FmodInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void
-  run(const Tensor& a, const Tensor& b, Tensor& out, bool& div_by_zero_error) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [&div_by_zero_error](const CTYPE_A val_a, const CTYPE_B val_b) {
-          if (is_integral_type<CTYPE_IN, /*includeBool=*/true>::value) {
-            if (val_b == 0) {
-              div_by_zero_error = true;
-              return static_cast<CTYPE_OUT>(0);
-            }
-          }
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = std::fmod(a_casted, b_casted);
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, Tensor&, bool&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct FmodInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug {};
-
-} // namespace
-
 Tensor& fmod_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
       out);
 
+  // Check Dim Order
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  auto div_by_zero_error = false;
-
-  ET_SWITCH_REAL_TYPES_AND(
-      Bool, a_type, ctx, "fmod.Tensor_out", CTYPE_A, [&]() {
-        ET_SWITCH_REAL_TYPES_AND(
-            Bool, b_type, ctx, "fmod.Tensor_out", CTYPE_B, [&]() {
-              using CTYPE_IN = typename torch::executor::
-                  promote_types<CTYPE_A, CTYPE_B>::type;
-              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-              ET_SWITCH_REAL_TYPES(
-                  out_type, ctx, "fmod.Tensor_out", CTYPE_OUT, [&]() {
-                    FmodInner<
-                        !std::is_same<CTYPE_IN, bool>::value &&
-                            can_cast<CTYPE_IN, CTYPE_OUT>::value,
-                        CTYPE_A,
-                        CTYPE_B,
-                        CTYPE_IN,
-                        CTYPE_OUT>::run(a, b, out, div_by_zero_error);
-                  });
-            });
-      });
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
+  }
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "fmod.Tensor_out";
+
+  bool div_by_zero_error = false;
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [&div_by_zero_error](
+            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          CTYPE_COMPUTE value = 0;
+          if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
+            if (val_b == 0) {
+              div_by_zero_error = true;
+              return value;
+            }
+          }
+          value = std::fmod(val_a, val_b);
+          return value;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
+  });
 
   ET_KERNEL_CHECK_MSG(
       ctx,
@@ -132,71 +92,56 @@ Tensor& fmod_Scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
 
-  // Resize for dynamic shape
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
+      ctx,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
+      InvalidArgument,
+      out);
+
+  // Check for integral division by zero
   ET_KERNEL_CHECK_MSG(
       ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
+      !(executorch::runtime::isIntegralType(common_type, true) &&
+        utils::scalar_to<double>(b) == 0),
       InvalidArgument,
       out,
-      "Failed to resize output tensor.");
+      "Fmod operation encountered integer division by zero");
 
+  // Check Dim Order
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
 
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+  if (compute_type != ScalarType::Float) {
+    compute_type = ScalarType::Double;
+  }
 
-  // Check for integer division by zero
-  if (isIntegralType(common_type, /*includeBool=*/true)) {
-    auto is_zero = false;
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, b_type, ctx, "fmod.Scalar_out", CTYPE_B, [&]() {
-          CTYPE_B val_b = 0;
-          utils::extract_scalar(b, &val_b);
-          is_zero = (val_b == 0);
-        });
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "fmod.Scalar_out";
 
-    ET_KERNEL_CHECK_MSG(
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b](const CTYPE_COMPUTE val_a) {
+          CTYPE_COMPUTE value = std::fmod(val_a, val_b);
+          return value;
+        },
         ctx,
-        !is_zero,
-        InvalidArgument,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
         out,
-        "Fmod operation encountered integer division by zero");
-  }
-
-  ET_SWITCH_REAL_TYPES_AND(
-      Bool, a_type, ctx, "fmod.Scalar_out", CTYPE_A, [&]() {
-        ET_SWITCH_SCALAR_OBJ_TYPES(
-            b_type, ctx, "fmod.Scalar_out", CTYPE_B, [&]() {
-              CTYPE_B val_b = 0;
-              utils::extract_scalar(b, &val_b);
-              ET_SWITCH_REAL_TYPES(
-                  common_type, ctx, "fmod.Scalar_out", CTYPE_IN, [&]() {
-                    ET_SWITCH_REAL_TYPES(
-                        out_type, ctx, "fmod.Scalar_out", CTYPE_OUT, [&]() {
-                          apply_unary_map_fn(
-                              [val_b](const CTYPE_A val_a) {
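Note: `fmod` forces the compute dtype to `Float` or `Double` because `std::fmod` is a floating-point operation; integral inputs are promoted on load and converted back on store. For reference, `std::fmod` truncates toward zero, so a nonzero result always carries the dividend's sign:

```cpp
#include <cmath>
#include <cstdio>

int main() {
  std::printf("%g\n", std::fmod(5.0, 3.0));   // 2
  std::printf("%g\n", std::fmod(-5.0, 3.0));  // -2 (sign of the dividend)
  std::printf("%g\n", std::fmod(5.0, -3.0));  // 2
  return 0;
}
```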
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value = std::fmod(a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a.const_data_ptr<CTYPE_A>(),
-                              out.mutable_data_ptr<CTYPE_OUT>(),
-                              out.numel());
-                        });
-                  });
-            });
-      });
+        utils::SupportedTensorDtypes::REALHBF16);
+  });
 
   return out;
 }
diff --git a/kernels/portable/cpu/op_remainder.cpp b/kernels/portable/cpu/op_remainder.cpp
index 8f25a72167a..d34c34a0380 100644
--- a/kernels/portable/cpu/op_remainder.cpp
+++ b/kernels/portable/cpu/op_remainder.cpp
@@ -9,8 +9,7 @@
 #include <cmath>
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/kernels/portable/cpu/util/math_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
@@ -18,96 +17,70 @@
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-
-namespace {
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct RemainderInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct RemainderInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void run(const Tensor& a, const Tensor& b, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = utils::remainder_override(a_casted, b_casted);
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct RemainderInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug {};
-
-} // namespace
 
 Tensor& remainder_Tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  (void)ctx;
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
 
-  // Determine output size and resize for dynamic shapes
+  // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
       InvalidArgument,
       out);
 
+  // Check Dim Order
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = promoteTypes(a_type, b_type);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_REAL_TYPES_AND(
-      Bool, a_type, ctx, "remainder.Tensor_out", CTYPE_A, [&]() {
-        ET_SWITCH_REAL_TYPES_AND(
-            Bool, b_type, ctx, "remainder.Tensor_out", CTYPE_B, [&]() {
-              using CTYPE_IN = typename torch::executor::
-                  promote_types<CTYPE_A, CTYPE_B>::type;
-              ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-              ET_SWITCH_REAL_TYPES(
-                  out_type, ctx, "remainder.Tensor_out", CTYPE_OUT, [&]() {
-                    RemainderInner<
-                        can_cast<CTYPE_IN, CTYPE_OUT>::value,
-                        CTYPE_A,
-                        CTYPE_B,
-                        CTYPE_IN,
-                        CTYPE_OUT>::run(a, b, out);
-                  });
-            });
-      });
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "remainder.Tensor_out";
+
+  bool div_by_zero_error = false;
+
+  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [&div_by_zero_error](
+            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          CTYPE_COMPUTE value = 0;
+          if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
+            if (val_b == 0) {
+              div_by_zero_error = true;
+              return value;
+            }
+          }
+          value = utils::remainder_override(val_a, val_b);
+          return value;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
+  });
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      !div_by_zero_error,
+      InvalidArgument,
+      out,
+      "Remainder operation encountered integer division by zero");
 
   return out;
 }
@@ -117,58 +90,52 @@ Tensor& remainder_Scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
+      ctx,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
+      InvalidArgument,
+      out);
 
-  // Resize for dynamic shape
+  // Check for integral division by zero
   ET_KERNEL_CHECK_MSG(
       ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
+      !(executorch::runtime::isIntegralType(common_type, true) &&
+        utils::scalar_to<double>(b) == 0),
      InvalidArgument,
      out,
-      "Failed to resize output tensor.");
+      "Remainder operation encountered integer division by zero");
 
+  // Check Dim Order
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_REAL_TYPES_AND(
-      Bool, a_type, ctx, "remainder.Scalar_out", CTYPE_A, [&]() {
-        ET_SWITCH_SCALAR_OBJ_TYPES(
-            b_type, ctx, "remainder.Scalar_out", CTYPE_B, [&]() {
-              CTYPE_B val_b = 0;
-              utils::extract_scalar(b, &val_b);
-              ET_SWITCH_REAL_TYPES(
-                  common_type, ctx, "remainder.Scalar_out", CTYPE_IN, [&]() {
-                    ET_SWITCH_REAL_TYPES(
-                        out_type,
-                        ctx,
-                        "remainder.Scalar_out",
-                        CTYPE_OUT,
-                        [&]() {
-                          apply_unary_map_fn(
-                              [val_b](const CTYPE_A val_a) {
-                                CTYPE_IN a_casted =
-                                    static_cast<CTYPE_IN>(val_a);
-                                CTYPE_IN b_casted =
-                                    static_cast<CTYPE_IN>(val_b);
-                                CTYPE_IN value = utils::remainder_override(
-                                    a_casted, b_casted);
-
-                                return static_cast<CTYPE_OUT>(value);
-                              },
-                              a.const_data_ptr<CTYPE_A>(),
-                              out.mutable_data_ptr<CTYPE_OUT>(),
-                              out.numel());
-                        });
-                  });
-            });
-      });
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "remainder.Scalar_out";
+
+  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b](const CTYPE_COMPUTE val_a) {
+          return utils::remainder_override(val_a, val_b);
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
+  });
 
   return out;
 }
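Note: unlike `fmod`, `remainder` keeps an integral compute path (`ET_SWITCH_REAL_TYPES`), and `utils::remainder_override` (from `math_util.h`) computes PyTorch-style remainder, where a nonzero result takes the sign of the divisor. A sketch of the floating-point case under that assumption — `remainder_like_torch` is a hypothetical stand-in, not the actual helper:

```cpp
#include <cmath>
#include <cstdio>

// Torch-style remainder: fold std::fmod's result into the divisor's sign.
double remainder_like_torch(double a, double b) {
  double r = std::fmod(a, b);
  if (r != 0 && (r < 0) != (b < 0)) {
    r += b;  // shift the result so it matches the divisor's sign
  }
  return r;
}

int main() {
  std::printf("%g\n", remainder_like_torch(-5.0, 3.0));  // 1
  std::printf("%g\n", std::fmod(-5.0, 3.0));             // -2
  return 0;
}
```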
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index ba8a63a8b56..26e16b8fecc 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -538,6 +538,7 @@ ATEN_OPS = (
         name = "op_floor_divide",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:math_util",
         ],
     ),
@@ -546,7 +547,7 @@ ATEN_OPS = (
         deps = [
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
     op_target(
@@ -938,7 +939,7 @@ ATEN_OPS = (
         deps = [
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:math_util",
         ],
     ),