diff --git a/kernels/portable/cpu/op_atan2.cpp b/kernels/portable/cpu/op_atan2.cpp
index bae4106f997..19267ef49dd 100644
--- a/kernels/portable/cpu/op_atan2.cpp
+++ b/kernels/portable/cpu/op_atan2.cpp
@@ -6,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include
 #include
@@ -14,42 +14,58 @@ namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
+namespace {
+
+ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
+  if (isFloatingType(a_type) && isFloatingType(b_type)) {
+    return promoteTypes(a_type, b_type);
+  } else if (isFloatingType(a_type)) {
+    return a_type;
+  } else if (isFloatingType(b_type)) {
+    return b_type;
+  }
+  return ScalarType::Float;
+}
+
+} // namespace
 
 Tensor& atan2_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
+  // Common Dtype
+  ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  // Resize
   ET_KERNEL_CHECK(
       ctx,
       resize_to_broadcast_target_size(a, b, out) == Error::Ok,
       InvalidArgument,
       out);
 
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "atan2.out";
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REALHB_TYPES(a_type, ctx, "atan2.out", CTYPE_A, [&]() {
-    ET_SWITCH_REALHB_TYPES(b_type, ctx, "atan2.out", CTYPE_B, [&]() {
-      ET_SWITCH_FLOATH_TYPES(out_type, ctx, "atan2.out", CTYPE_OUT, [&]() {
-        apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-            [](const CTYPE_A val_a, const CTYPE_B val_b) {
-              CTYPE_OUT casted_a = static_cast<CTYPE_OUT>(val_a);
-              CTYPE_OUT casted_b = static_cast<CTYPE_OUT>(val_b);
-              return static_cast<CTYPE_OUT>(std::atan2(casted_a, casted_b));
-            },
-            a,
-            b,
-            out);
-      });
-    });
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return std::atan2(val_a, val_b);
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::FLOATHBF16);
   });
 
   return out;
diff --git a/kernels/portable/cpu/op_div.cpp b/kernels/portable/cpu/op_div.cpp
index 9a1c34c0f14..9f33907b998 100644
--- a/kernels/portable/cpu/op_div.cpp
+++ b/kernels/portable/cpu/op_div.cpp
@@ -7,8 +7,7 @@
  */
 
 #include
-#include
-#include
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include
 #include
 #include
@@ -20,7 +19,7 @@ namespace native {
 
 namespace {
 
-ScalarType get_compute_type(ScalarType a_type, ScalarType b_type) {
+ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
   if (isFloatingType(a_type) && isFloatingType(b_type)) {
     return promoteTypes(a_type, b_type);
   } else if (isFloatingType(a_type)) {
@@ -38,54 +37,38 @@ Tensor& div_out(
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
+  // Common Dtype
+  ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
 
+  // Check Dim Order
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-
+  // Resize
   ET_KERNEL_CHECK(
       ctx,
-      !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type),
-      InvalidArgument,
-      out);
-  ET_KERNEL_CHECK(
-      ctx,
-      !isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type),
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
       InvalidArgument,
       out);
 
-  ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out);
-
-  ScalarType common_type = get_compute_type(a_type, b_type);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.out", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out", CTYPE_B, [&]() {
-      ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
-        ET_SWITCH_FLOAT_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
-          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-              [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                CTYPE_IN value = a_casted / b_casted;
-
-                return static_cast<CTYPE_OUT>(value);
-              },
-              a,
-              b,
-              out);
-        });
-      });
-    });
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "div.out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          return val_a / val_b;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::FLOATHBF16);
   });
 
   return out;
@@ -97,54 +80,83 @@ Tensor& div_out_mode(
     const Tensor& b,
     exec_aten::optional<exec_aten::string_view> mode,
     Tensor& out) {
+  if (!mode.has_value()) {
+    return div_out(ctx, a, b, out);
+  }
+
+  auto mode_val = mode.value();
+
+  // Check mode
+  ET_KERNEL_CHECK(
+      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
+
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+
+  // Check Common Dtype
   ET_KERNEL_CHECK(
       ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
      InvalidArgument,
      out);
 
+  // Check Dim Order
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType common_type = get_compute_type(a_type, b_type);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out);
-
-  // Allow casting float -> integral here
-  // non-bool -> bool is still disallowed
+  // Resize
   ET_KERNEL_CHECK(
       ctx,
-      !(common_type != ScalarType::Bool && out_type == ScalarType::Bool),
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
       InvalidArgument,
       out);
 
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.out_mode", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out_mode", CTYPE_B, [&]() {
-      ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div.out_mode", CTYPE_IN, [&]() {
-        ET_SWITCH_REAL_TYPES(out_type, ctx, "div.out_mode", CTYPE_OUT, [&]() {
-          apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-              [mode](const CTYPE_A val_a, const CTYPE_B val_b) {
-                CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                CTYPE_IN value = a_casted / b_casted;
-                if (mode.has_value() && mode.value() == "trunc") {
-                  value = std::trunc(value);
-                } else if (mode.has_value() && mode.value() == "floor") {
-                  value = std::floor(value);
-                }
-                return static_cast<CTYPE_OUT>(value);
-              },
-              a,
-              b,
-              out);
-        });
-      });
-    });
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "div.out_mode";
+
+  const bool mode_is_trunc = mode_val == "trunc";
+  bool div_by_zero_error = false;
+
+  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [mode_is_trunc, &div_by_zero_error](
+            const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
+          if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
+            if (val_b == 0) {
+              div_by_zero_error = true;
+              return static_cast<CTYPE_COMPUTE>(0);
+            }
+          }
+          CTYPE_COMPUTE value = val_a / val_b;
+          if (mode_is_trunc) {
+            value = std::trunc(value);
+          } else {
+            // We established above that the mode is either trunc or floor, so
+            // it must be floor.
+            value = utils::floor_divide(val_a, val_b);
+          }
+          return value;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });
 
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      !div_by_zero_error,
+      InvalidArgument,
+      out,
+      "Div mode operation encountered integer division by zero");
+
   return out;
 }
@@ -153,44 +165,36 @@ Tensor& div_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
+  // Common Dtype
+  ScalarType common_type =
+      isFloatingType(a.scalar_type()) ? a.scalar_type() : ScalarType::Float;
 
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = isFloatingType(a_type) ? a_type : ScalarType::Float;
-  ScalarType out_type = out.scalar_type();
+  // Check Common Dtype
+  ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out);
 
+  // Check Dim Order
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
 
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "div.Scalar_out", CTYPE_B, [&]() {
-      ET_SWITCH_FLOAT_TYPES(out_type, ctx, "div.Scalar_out", CTYPE, [&]() {
-        CTYPE_B b_val;
-        utils::extract_scalar(b, &b_val);
-        CTYPE b_casted = static_cast<CTYPE>(b_val);
-
-        apply_unary_map_fn(
-            [b_casted](const CTYPE_A val_a) {
-              CTYPE a_casted = static_cast<CTYPE>(val_a);
-              CTYPE value = a_casted / b_casted;
-              return static_cast<CTYPE>(value);
-            },
-            a.const_data_ptr<CTYPE_A>(),
-            out.mutable_data_ptr<CTYPE>(),
-            out.numel());
-      });
-    });
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "div.Scalar_out";
+
+  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;
@@ -202,48 +206,69 @@ Tensor& div_scalar_mode_out(
     const Scalar& b,
     exec_aten::optional<exec_aten::string_view> mode,
     Tensor& out) {
-  (void)ctx;
+  if (!mode.has_value()) {
+    return div_scalar_out(ctx, a, b, out);
+  }
+
+  auto mode_val = mode.value();
+
+  // Check mode
+  ET_KERNEL_CHECK(
+      ctx, mode_val == "trunc" || mode_val == "floor", InvalidArgument, out);
 
-  // Resize for dynamic shape
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
+
+  // Check Common Dtype
+  ET_KERNEL_CHECK(
+      ctx,
+      (canCast(common_type, out.scalar_type()) &&
+       common_type != ScalarType::Bool),
+      InvalidArgument,
+      out);
+
+  // Check for integral division by zero
   ET_KERNEL_CHECK_MSG(
       ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
+      !(executorch::runtime::isIntegralType(common_type, true) &&
+        utils::scalar_to<int64_t>(b) == 0),
       InvalidArgument,
       out,
-      "Failed to resize output tensor.");
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
-
-  constexpr auto name = "div.Scalar_mode_out";
-
-  ET_SWITCH_REALB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES(out_type, ctx, name, CTYPE, [&]() {
-        CTYPE_B b_val;
-        utils::extract_scalar(b, &b_val);
-        CTYPE b_casted = static_cast<CTYPE>(b_val);
-
-        apply_unary_map_fn(
-            [b_casted, mode](const CTYPE_A val_a) {
-              CTYPE a_casted = static_cast<CTYPE>(val_a);
-              CTYPE value = a_casted / b_casted;
-              if (mode.has_value() && mode.value() == "trunc") {
-                value = std::trunc(value);
-              } else if (mode.has_value() && mode.value() == "floor") {
-                value = utils::floor_divide(a_casted, b_casted);
-              }
-              return value;
-            },
-            a.const_data_ptr<CTYPE_A>(),
-            out.mutable_data_ptr<CTYPE>(),
-            out.numel());
-      });
-    });
+      "Div mode operation encountered integer division by zero");
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  const bool mode_is_trunc = mode_val == "trunc";
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "div.Scalar_mode_out";
+
+  ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b, mode_is_trunc](const CTYPE_COMPUTE val_a) {
+          CTYPE_COMPUTE value = val_a / val_b;
+          if (mode_is_trunc) {
+            value = std::trunc(value);
+          } else {
+            value = utils::floor_divide(val_a, val_b);
+          }
+          return value;
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBF16);
   });
 
   return out;
diff --git a/kernels/portable/cpu/util/elementwise_util.cpp b/kernels/portable/cpu/util/elementwise_util.cpp
index 1086d8743c9..2988c604e8c 100644
--- a/kernels/portable/cpu/util/elementwise_util.cpp
+++ b/kernels/portable/cpu/util/elementwise_util.cpp
@@ -23,6 +23,8 @@ bool check_tensor_dtype(
       return executorch::runtime::tensor_is_realhbbf16_type(t);
     case SupportedTensorDtypes::REALHBF16:
       return executorch::runtime::tensor_is_realhbf16_type(t);
+    case SupportedTensorDtypes::FLOATHBF16:
+      return executorch::runtime::tensor_is_floating_type(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return (
           executorch::runtime::tensor_is_type(t, ScalarType::Bool) ||
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 788cfa85f47..ae211a4cf08 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -80,6 +80,17 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
   return result;
 }
 
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16(
+    const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_FLOATHBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+      });
+  return result;
+}
+
 template <typename CTYPE_COMMON, const char* op_name>
 load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
     const Tensor& t) {
@@ -151,6 +162,17 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
   return result;
 }
 
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_floathbf16(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_FLOATHBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+
 template <typename CTYPE_COMMON, const char* op_name>
 store_common_to_tensor_fn<CTYPE_COMMON>
 get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
@@ -203,6 +225,7 @@ get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
 enum class SupportedTensorDtypes {
   REALHBBF16,
   REALHBF16,
+  FLOATHBF16,
   BOOL_OR_BYTE,
   SAME_AS_COMPUTE,
   SAME_AS_COMMON,
 };
@@ -219,6 +242,8 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
       return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::REALHBF16:
       return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::FLOATHBF16:
+      return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::SAME_AS_COMPUTE:
@@ -239,6 +264,8 @@ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
       return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::REALHBF16:
       return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::FLOATHBF16:
+      return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
     case SupportedTensorDtypes::BOOL_OR_BYTE:
       return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
           t);
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index 475c656a41a..ba8a63a8b56 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -304,7 +304,7 @@ ATEN_OPS = (
         name = "op_atan2",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
     op_target(
@@ -463,7 +463,7 @@ ATEN_OPS = (
         name = "op_div",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:math_util",
             ":scalar_utils",
         ],