diff --git a/kernels/portable/cpu/op_pow.cpp b/kernels/portable/cpu/op_pow.cpp index c0130933135..81319b03d9f 100644 --- a/kernels/portable/cpu/op_pow.cpp +++ b/kernels/portable/cpu/op_pow.cpp @@ -9,101 +9,61 @@ #include #include -#include -#include -#include +#include #include namespace torch { namespace executor { namespace native { -using Tensor = exec_aten::Tensor; - -namespace { -template < - bool can_cast, - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct PowInner; - -template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct PowInner { - static void run(const Tensor& a, const Tensor& b, Tensor& out) { - apply_binary_elementwise_fn( - // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = std::pow(a_casted, b_casted); - return static_cast(value); - }, - a, - b, - out); - } -}; - -struct ReportCanCastBug { - static void run(const Tensor&, const Tensor&, Tensor&) { - ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); - } -}; - -template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct PowInner - : public ReportCanCastBug {}; - -} // namespace - Tensor& pow_Tensor_Tensor_out( KernelRuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) { - // Determine output size and resize for dynamic shapes + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype ET_KERNEL_CHECK( ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, + (canCast(common_type, out.scalar_type()) && + common_type != ScalarType::Bool), InvalidArgument, out); - ScalarType a_type = a.scalar_type(); - ScalarType b_type = b.scalar_type(); - ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true); - ScalarType out_type = out.scalar_type(); + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); + // Resize ET_KERNEL_CHECK( - ctx, common_type != exec_aten::ScalarType::Bool, InvalidArgument, out); - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); - - ET_SWITCH_REALHB_TYPES(a_type, ctx, "pow.Tensor_Tensor_out", CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES( - b_type, ctx, "pow.Tensor_Tensor_out", CTYPE_B, [&]() { - using CTYPE_IN = typename torch::executor:: - promote_types::type; - ET_DCHECK(CppTypeToScalarType::value == common_type); - ET_SWITCH_REALH_TYPES( - out_type, ctx, "pow.Tensor_Tensor_out", CTYPE_OUT, [&]() { - PowInner< - !std::is_same::value && - can_cast::value, - CTYPE_A, - CTYPE_B, - CTYPE_IN, - CTYPE_OUT>::run(a, b, out); - }); - }); + ctx, + resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + if (compute_type != ScalarType::Float) { + compute_type = ScalarType::Double; + } + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "pow.Tensor_Tensor_out"; + + ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return std::pow(val_a, val_b); + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBF16); }); return out; @@ -114,51 +74,43 @@ Tensor& pow_Tensor_Scalar_out( const Tensor& a, const Scalar& b, Tensor& out) { - (void)ctx; + // Common Dtype + ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); - // Resize for dynamic shape - ET_KERNEL_CHECK_MSG( + // Check Common Dtype + ET_KERNEL_CHECK( ctx, - resize_tensor(out, a.sizes()) == Error::Ok, + (canCast(common_type, out.scalar_type()) && + common_type != ScalarType::Bool), InvalidArgument, - out, - "Failed to resize output tensor."); + out); - ScalarType a_type = a.scalar_type(); - ScalarType b_type = utils::get_scalar_dtype(b); - ScalarType common_type = - utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false); - ScalarType out_type = out.scalar_type(); + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + // Resize + ET_KERNEL_CHECK( + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); - if (common_type == ScalarType::Half) { - common_type = ScalarType::Float; + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + if (compute_type != ScalarType::Float) { + compute_type = ScalarType::Double; } - ET_SWITCH_REALHB_TYPES(a_type, ctx, "pow.Tensor_Scalar_out", CTYPE_A, [&]() { - ET_SWITCH_SCALAR_OBJ_TYPES( - b_type, ctx, "pow.Tensor_Scalar_out", CTYPE_B, [&]() { - ET_SWITCH_REAL_TYPES( - common_type, ctx, "pow.Tensor_Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALH_TYPES( - out_type, ctx, "pow.Tensor_Scalar_out", CTYPE_OUT, [&]() { - CTYPE_B val_b = 0; - utils::extract_scalar(b, &val_b); - apply_unary_map_fn( - [val_b](const CTYPE_A val_a) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = std::pow(a_casted, b_casted); - - return static_cast(value); - }, - a.const_data_ptr(), - out.mutable_data_ptr(), - out.numel()); - }); - }); - }); + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "pow.Tensor_Scalar_out"; + + ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + utils::apply_unitensor_elementwise_fn( + [val_b](const CTYPE_COMPUTE val_a) { return std::pow(val_a, val_b); }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBF16); }); return out; @@ -169,49 +121,43 @@ Tensor& pow_Scalar_out( const Scalar& a, const Tensor& b, Tensor& out) { - (void)ctx; + // Common Dtype + ScalarType common_type = utils::promote_type_with_scalar(b.scalar_type(), a); - // Resize for dynamic shape - ET_KERNEL_CHECK_MSG( + // Check Common Dtype + ET_KERNEL_CHECK( ctx, - resize_tensor(out, b.sizes()) == Error::Ok, + (canCast(common_type, out.scalar_type()) && + common_type != ScalarType::Bool), InvalidArgument, - out, - "Failed to resize output tensor."); + out); - ScalarType a_type = utils::get_scalar_dtype(a); - ScalarType b_type = b.scalar_type(); - ScalarType common_type = - utils::promote_type_with_scalar(b_type, a, /*half_to_float*/ false); - ScalarType out_type = out.scalar_type(); + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(b, out), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + // Resize + ET_KERNEL_CHECK( + ctx, resize_tensor(out, b.sizes()) == Error::Ok, InvalidArgument, out); - if (common_type == ScalarType::Half) { - common_type = ScalarType::Float; + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + if (compute_type != ScalarType::Float) { + compute_type = ScalarType::Double; } - ET_SWITCH_SCALAR_OBJ_TYPES(a_type, ctx, "pow.Scalar_out", CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, "pow.Scalar_out", CTYPE_B, [&]() { - ET_SWITCH_REAL_TYPES(common_type, ctx, "pow.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALH_TYPES( - out_type, ctx, "pow.Scalar_out", CTYPE_OUT, [&]() { - CTYPE_A val_a = 0; - utils::extract_scalar(a, &val_a); - - apply_unary_map_fn( - [val_a](const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = std::pow(a_casted, b_casted); - return static_cast(value); - }, - b.const_data_ptr(), - out.mutable_data_ptr(), - out.numel()); - }); - }); - }); + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "pow.Scalar_out"; + + ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_a = utils::scalar_to(a); + utils::apply_unitensor_elementwise_fn( + [val_a](const CTYPE_COMPUTE val_b) { return std::pow(val_a, val_b); }, + ctx, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBF16); }); return out; diff --git a/kernels/portable/cpu/op_rsub.cpp b/kernels/portable/cpu/op_rsub.cpp index 5445c2df1f7..46af021efda 100644 --- a/kernels/portable/cpu/op_rsub.cpp +++ b/kernels/portable/cpu/op_rsub.cpp @@ -7,8 +7,7 @@ */ #include -#include -#include +#include #include namespace torch { @@ -21,57 +20,47 @@ Tensor& rsub_scalar_out( const Scalar& b, const Scalar& alpha, Tensor& out) { - (void)ctx; + ScalarType alpha_type = utils::get_scalar_dtype(alpha); + + // Check alpha type + ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out); - // Resize for dynamic shape - ET_KERNEL_CHECK_MSG( + // Common Dtype + ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); + + // Check Common Dtype + ET_KERNEL_CHECK( ctx, - resize_tensor(out, a.sizes()) == Error::Ok, + (common_type == out.scalar_type() && canCast(alpha_type, common_type)), InvalidArgument, - out, - "Failed to resize output tensor."); + out); + // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out); - - ScalarType a_type = a.scalar_type(); - ScalarType b_type = utils::get_scalar_dtype(b); - ScalarType alpha_type = utils::get_scalar_dtype(alpha); - ScalarType common_type = utils::promote_type_with_scalar(a_type, b); - ScalarType out_type = out.scalar_type(); - - ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); + // Resize ET_KERNEL_CHECK( - ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out); + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); - ET_SWITCH_REAL_TYPES(a_type, ctx, "rsub.Scalar_out", CTYPE_A, [&]() { - ET_SWITCH_SCALAR_OBJ_REAL_TYPES( - b_type, ctx, "rsub.Scalar_out", CTYPE_B, [&]() { - ET_SWITCH_REAL_TYPES( - common_type, ctx, "rsub.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REAL_TYPES( - out_type, ctx, "rsub.Scalar_out", CTYPE_OUT, [&]() { - CTYPE_B b_val; - utils::extract_scalar(b, &b_val); - CTYPE_IN b_casted = static_cast(b_val); - CTYPE_IN alpha_val; - utils::extract_scalar(alpha, &alpha_val); + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "rsub.Scalar_out"; - apply_unary_map_fn( - [b_casted, alpha_val](const CTYPE_A val_a) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN value = b_casted - alpha_val * a_casted; - return static_cast(value); - }, - a.const_data_ptr(), - out.mutable_data_ptr(), - out.numel()); - }); - }); - }); + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_unitensor_elementwise_fn( + [val_b, val_alpha](const CTYPE_COMPUTE val_a) { + return val_b - val_alpha * val_a; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/op_sub.cpp b/kernels/portable/cpu/op_sub.cpp index d366c40b771..6217f82c3b1 100644 --- a/kernels/portable/cpu/op_sub.cpp +++ b/kernels/portable/cpu/op_sub.cpp @@ -7,64 +7,13 @@ */ #include -#include -#include -#include +#include #include #include namespace torch { namespace executor { namespace native { -namespace { - -template < - bool can_cast, - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct SubInner; - -template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct SubInner { - static void - run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) { - apply_binary_elementwise_fn( - // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) - [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = a_casted - alpha_val * b_casted; - - return static_cast(value); - }, - a, - b, - out); - } -}; - -template -struct ReportCanCastBug { - static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) { - ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); - } -}; - -template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct SubInner - : public ReportCanCastBug {}; - -} // namespace Tensor& sub_out( KernelRuntimeContext& ctx, @@ -72,45 +21,52 @@ Tensor& sub_out( const Tensor& b, const Scalar& alpha, Tensor& out) { + ScalarType alpha_type = utils::get_scalar_dtype(alpha); + + // Check alpha type + ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out); + + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype ET_KERNEL_CHECK( ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, + (canCast(common_type, out.scalar_type()) && + canCast(alpha_type, common_type)), InvalidArgument, out); + // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out); + // Resize + ET_KERNEL_CHECK( + ctx, + resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); - ScalarType a_type = a.scalar_type(); - ScalarType b_type = b.scalar_type(); - ScalarType alpha_type = utils::get_scalar_dtype(alpha); - ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true); - ScalarType out_type = out.scalar_type(); + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); - ET_KERNEL_CHECK( - ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out); - - constexpr auto name = "sub.out"; - - ET_SWITCH_REALH_TYPES(a_type, ctx, name, CTYPE_A, [&]() { - ET_SWITCH_REALH_TYPES(b_type, ctx, name, CTYPE_B, [&]() { - using CTYPE_IN = typename torch::executor:: - promote_types::type; - ET_DCHECK(CppTypeToScalarType::value == common_type); - CTYPE_IN alpha_val; - utils::extract_scalar(alpha, &alpha_val); - ET_SWITCH_REALH_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() { - SubInner< - can_cast::value, - CTYPE_A, - CTYPE_B, - CTYPE_IN, - CTYPE_OUT>::run(a, b, alpha_val, out); - }); - }); + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "sub.out"; + + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_bitensor_elementwise_fn( + [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a - val_alpha * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBF16, + b, + utils::SupportedTensorDtypes::REALHBF16, + out, + utils::SupportedTensorDtypes::REALHBF16); }); return out; @@ -122,67 +78,47 @@ Tensor& sub_scalar_out( const Scalar& b, const Scalar& alpha, Tensor& out) { - (void)ctx; + ScalarType alpha_type = utils::get_scalar_dtype(alpha); - // Resize for dynamic shape - ET_KERNEL_CHECK_MSG( + // Check alpha type + ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out); + + // Common Dtype + ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); + + // Check Common Dtype + ET_KERNEL_CHECK( ctx, - resize_tensor(out, a.sizes()) == Error::Ok, + (common_type == out.scalar_type() && canCast(alpha_type, common_type)), InvalidArgument, - out, - "Failed to resize output tensor."); - - ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out); + out); + // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); - ScalarType a_type = a.scalar_type(); - ScalarType b_type = utils::get_scalar_dtype(b); - ScalarType alpha_type = utils::get_scalar_dtype(alpha); - ScalarType common_type = - utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false); - ScalarType out_type = out.scalar_type(); - - ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); - ET_KERNEL_CHECK(ctx, canCast(alpha_type, common_type), InvalidArgument, out); - - if (common_type == ScalarType::Half) { - common_type = ScalarType::Float; - } - - constexpr auto name = "sub.Scalar_out"; - - ET_SWITCH_REALH_TYPES(a_type, ctx, name, CTYPE_A, [&]() { - ET_SWITCH_SCALAR_OBJ_REAL_TYPES(b_type, ctx, name, CTYPE_B, [&]() { - using CTYPE_IN = typename utils::promote_type_with_scalar_type< - CTYPE_A, - CTYPE_B, - /*half_to_float*/ true>::type; - ET_DCHECK(CppTypeToScalarType::value == common_type); - - CTYPE_B b_val; - utils::extract_scalar(b, &b_val); - CTYPE_IN b_casted = static_cast(b_val); - - CTYPE_IN alpha_val; - utils::extract_scalar(alpha, &alpha_val); - - using CTYPE_OUT = typename std::conditional< - std::is_same::value, - internal::F2, - CTYPE_IN>::type; - - apply_unary_map_fn( - [b_casted, alpha_val](const CTYPE_A val_a) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN value = a_casted - alpha_val * b_casted; - return static_cast(value); - }, - a.const_data_ptr(), - out.mutable_data_ptr(), - out.numel()); - }); + // Resize + ET_KERNEL_CHECK( + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "sub.Scalar_out"; + + ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + const CTYPE_COMPUTE val_alpha = utils::scalar_to(alpha); + utils::apply_unitensor_elementwise_fn( + [val_b, val_alpha](const CTYPE_COMPUTE val_a) { + return val_a - val_alpha * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/kernels/portable/cpu/util/elementwise_util.cpp b/kernels/portable/cpu/util/elementwise_util.cpp index bafb7e464c0..1086d8743c9 100644 --- a/kernels/portable/cpu/util/elementwise_util.cpp +++ b/kernels/portable/cpu/util/elementwise_util.cpp @@ -21,6 +21,8 @@ bool check_tensor_dtype( switch (dtypes) { case SupportedTensorDtypes::REALHBBF16: return executorch::runtime::tensor_is_realhbbf16_type(t); + case SupportedTensorDtypes::REALHBF16: + return executorch::runtime::tensor_is_realhbf16_type(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return ( executorch::runtime::tensor_is_type(t, ScalarType::Bool) || diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h index 28b3e964dbf..788cfa85f47 100644 --- a/kernels/portable/cpu/util/elementwise_util.h +++ b/kernels/portable/cpu/util/elementwise_util.h @@ -69,6 +69,17 @@ load_to_common_fn get_load_to_common_fn_realhbbf16( return result; } +template +load_to_common_fn get_load_to_common_fn_realhbf16( + const Tensor& t) { + CTYPE_COMMON (*result)(const void*) = nullptr; + ET_SWITCH_REALHBF16_TYPES( + t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + result = internal::load_and_convert; + }); + return result; +} + template load_to_common_fn get_load_to_common_fn_bool_or_byte( const Tensor& t) { @@ -129,6 +140,17 @@ get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) { return result; } +template +store_common_to_tensor_fn get_store_common_to_tensor_fn_realhbf16( + const Tensor& t) { + void (*result)(CTYPE_COMMON, void*) = nullptr; + ET_SWITCH_REALHBF16_TYPES( + t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() { + result = internal::convert_and_store; + }); + return result; +} + template store_common_to_tensor_fn get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) { @@ -180,6 +202,7 @@ get_store_common_to_tensor_fn_same_as_common(const Tensor& t) { enum class SupportedTensorDtypes { REALHBBF16, + REALHBF16, BOOL_OR_BYTE, SAME_AS_COMPUTE, SAME_AS_COMMON, @@ -194,6 +217,8 @@ load_to_common_fn get_load_to_common_fn( switch (dtypes) { case SupportedTensorDtypes::REALHBBF16: return get_load_to_common_fn_realhbbf16(t); + case SupportedTensorDtypes::REALHBF16: + return get_load_to_common_fn_realhbf16(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_load_to_common_fn_bool_or_byte(t); case SupportedTensorDtypes::SAME_AS_COMPUTE: @@ -212,6 +237,8 @@ store_common_to_tensor_fn get_store_common_to_tensor_fn( switch (dtypes) { case SupportedTensorDtypes::REALHBBF16: return get_store_common_to_tensor_fn_realhbbf16(t); + case SupportedTensorDtypes::REALHBF16: + return get_store_common_to_tensor_fn_realhbf16(t); case SupportedTensorDtypes::BOOL_OR_BYTE: return get_store_common_to_tensor_fn_bool_or_byte( t); diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index e25c5e36920..e500167fa04 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -456,6 +456,10 @@ inline bool isRealHBType(::executorch::aten::ScalarType t) { return (isRealHType(t) || t == ::executorch::aten::ScalarType::Bool); } +inline bool isRealHBF16Type(::executorch::aten::ScalarType t) { + return (isRealHType(t) || t == ::executorch::aten::ScalarType::BFloat16); +} + inline bool isRealHBBF16Type(::executorch::aten::ScalarType t) { return (isRealHBType(t) || t == ::executorch::aten::ScalarType::BFloat16); } @@ -1275,6 +1279,10 @@ inline ::executorch::aten::ScalarType promoteTypes( #define ET_SWITCH_REALH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_SWITCH_REAL_TYPES_AND(Half, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) +#define ET_SWITCH_REALHBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_REAL_TYPES_AND2( \ + Half, BFloat16, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) + #define ET_SWITCH_REALB_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_SWITCH_REAL_TYPES_AND(Bool, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index 28395197bce..53b65e3f16a 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -517,6 +517,15 @@ inline bool tensor_is_realh_type(exec_aten::Tensor t) { return true; } +inline bool tensor_is_realhbf16_type(exec_aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + executorch::runtime::isRealHBF16Type(t.scalar_type()), + "Expected to find a real type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + inline bool tensor_is_realhb_type(exec_aten::Tensor t) { ET_LOG_MSG_AND_RETURN_IF_FALSE( torch::executor::isRealHBType(t.scalar_type()), diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index 9a91d710eba..475c656a41a 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -892,7 +892,7 @@ ATEN_OPS = ( deps = [ ":scalar_utils", "//executorch/kernels/portable/cpu/util:broadcast_util", - "//executorch/kernels/portable/cpu/util:functional_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", ], ), op_target( @@ -985,8 +985,8 @@ ATEN_OPS = ( name = "op_rsub", deps = [ ":scalar_utils", - "//executorch/kernels/portable/cpu/util:functional_util", - "//executorch/kernels/portable/cpu/util:kernel_ops_util", + "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", ], ), op_target( @@ -1104,8 +1104,7 @@ ATEN_OPS = ( deps = [ ":scalar_utils", "//executorch/kernels/portable/cpu/util:broadcast_util", - "//executorch/kernels/portable/cpu/util:functional_util", - "//executorch/kernels/portable/cpu/util:kernel_ops_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", ], ), op_target(