From 85fbe0d474729cd4fe508ee6a21c9cb6c037662e Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 8 Oct 2024 14:54:59 -0700 Subject: [PATCH] [ET][Portable][Build Size] REALHBF16 binary ops: maximum, minimum, mul - mul: 1.69 M -> 15 K - maximum: 353 K -> 11 K - minimum: 353 K -> 11 K Differential Revision: [D63909726](https://our.internmc.facebook.com/intern/diff/D63909726/) [ghstack-poisoned] --- kernels/portable/cpu/op_maximum.cpp | 98 +++------- kernels/portable/cpu/op_minimum.cpp | 98 +++------- kernels/portable/cpu/op_mul.cpp | 173 ++++++------------ .../kernels/portable/op_registration_util.bzl | 4 +- 4 files changed, 108 insertions(+), 265 deletions(-) diff --git a/kernels/portable/cpu/op_maximum.cpp b/kernels/portable/cpu/op_maximum.cpp index a2075151357..89eecf13ba5 100644 --- a/kernels/portable/cpu/op_maximum.cpp +++ b/kernels/portable/cpu/op_maximum.cpp @@ -7,98 +7,54 @@ */ #include -#include +#include #include #include namespace torch { namespace executor { namespace native { -namespace { - -template < - bool can_cast, - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct MaximumInner; - -template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct MaximumInner { - static void run(const Tensor& a, const Tensor& b, Tensor& out) { - apply_binary_elementwise_fn( - // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = utils::max_override(a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - } -}; - -struct ReportCanCastBug { - static void run(const Tensor&, const Tensor&, Tensor&) { - ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); - } -}; - -template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct MaximumInner - : public ReportCanCastBug {}; - -} // namespace Tensor& maximum_out( KernelRuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) { - (void)ctx; + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); + + // Resize ET_KERNEL_CHECK( ctx, resize_to_broadcast_target_size(a, b, out) == Error::Ok, InvalidArgument, out); - ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); - - ScalarType a_type = a.scalar_type(); - ScalarType b_type = b.scalar_type(); - ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true); - ScalarType out_type = out.scalar_type(); + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); + static constexpr const char op_name[] = "maximum.out"; - ET_SWITCH_REALHB_TYPES(a_type, ctx, "maximum.out", CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, "maximum.out", CTYPE_B, [&]() { - using CTYPE_IN = typename torch::executor:: - promote_types::type; - ET_DCHECK(CppTypeToScalarType::value == common_type); - ET_SWITCH_REALHB_TYPES(out_type, ctx, "maximum.out", CTYPE_OUT, [&]() { - MaximumInner< - can_cast::value, - CTYPE_A, - CTYPE_B, - CTYPE_IN, - CTYPE_OUT>::run(a, b, out); - }); - }); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return utils::max_override(val_a, val_b); + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; diff --git a/kernels/portable/cpu/op_minimum.cpp b/kernels/portable/cpu/op_minimum.cpp index cfcb3fe9aca..11677facc7f 100644 --- a/kernels/portable/cpu/op_minimum.cpp +++ b/kernels/portable/cpu/op_minimum.cpp @@ -7,98 +7,54 @@ */ #include -#include +#include #include #include namespace torch { namespace executor { namespace native { -namespace { - -template < - bool can_cast, - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct MinimumInner; - -template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct MinimumInner { - static void run(const Tensor& a, const Tensor& b, Tensor& out) { - apply_binary_elementwise_fn( - // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = utils::min_override(a_casted, b_casted); - - return static_cast(value); - }, - a, - b, - out); - } -}; - -struct ReportCanCastBug { - static void run(const Tensor&, const Tensor&, Tensor&) { - ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); - } -}; - -template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct MinimumInner - : public ReportCanCastBug {}; - -} // namespace Tensor& minimum_out( KernelRuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) { - (void)ctx; + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); + + // Resize ET_KERNEL_CHECK( ctx, resize_to_broadcast_target_size(a, b, out) == Error::Ok, InvalidArgument, out); - ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); - - ScalarType a_type = a.scalar_type(); - ScalarType b_type = b.scalar_type(); - ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true); - ScalarType out_type = out.scalar_type(); + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); + static constexpr const char op_name[] = "minimum.out"; - ET_SWITCH_REALHB_TYPES(a_type, ctx, "minimum.out", CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, "minimum.out", CTYPE_B, [&]() { - using CTYPE_IN = typename torch::executor:: - promote_types::type; - ET_DCHECK(CppTypeToScalarType::value == common_type); - ET_SWITCH_REALHB_TYPES(out_type, ctx, "minimum.out", CTYPE_OUT, [&]() { - MinimumInner< - can_cast::value, - CTYPE_A, - CTYPE_B, - CTYPE_IN, - CTYPE_OUT>::run(a, b, out); - }); - }); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return utils::min_override(val_a, val_b); + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp index 3e3b601e726..e03b1efac85 100644 --- a/kernels/portable/cpu/op_mul.cpp +++ b/kernels/portable/cpu/op_mul.cpp @@ -7,8 +7,7 @@ */ #include -#include -#include +#include #include #include @@ -16,93 +15,46 @@ namespace torch { namespace executor { namespace native { -namespace { -template < - bool can_cast, - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct MulInner; - -template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct MulInner { - static void run(const Tensor& a, const Tensor& b, Tensor& out) { - apply_binary_elementwise_fn( - // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue) - [](const CTYPE_A val_a, const CTYPE_B val_b) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN b_casted = static_cast(val_b); - CTYPE_IN value = a_casted * b_casted; - - return static_cast(value); - }, - a, - b, - out); - } -}; - -struct ReportCanCastBug { - static void run(const Tensor&, const Tensor&, Tensor&) { - ET_DCHECK_MSG(false, "BUG: canCast should have been checked above"); - } -}; - -template < - typename CTYPE_A, - typename CTYPE_B, - typename CTYPE_IN, - typename CTYPE_OUT> -struct MulInner - : public ReportCanCastBug {}; -} // namespace - Tensor& mul_out( KernelRuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) { + // Common Dtype + ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype ET_KERNEL_CHECK( - ctx, - resize_to_broadcast_target_size(a, b, out) == Error::Ok, - InvalidArgument, - out); + ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); + // Check Dim Order + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); + + // Resize ET_KERNEL_CHECK( ctx, - executorch::runtime::tensor_is_realhbbf16_type(out), + resize_to_broadcast_target_size(a, b, out) == Error::Ok, InvalidArgument, out); - ET_KERNEL_CHECK( - ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); + + static constexpr const char op_name[] = "mul.out"; - ScalarType a_type = a.scalar_type(); - ScalarType b_type = b.scalar_type(); - ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true); - ScalarType out_type = out.scalar_type(); - - ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); - - ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() { - ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() { - using CTYPE_IN = typename torch::executor:: - promote_types::type; - ET_DCHECK(CppTypeToScalarType::value == common_type); - ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() { - MulInner< - can_cast::value, - CTYPE_A, - CTYPE_B, - CTYPE_IN, - CTYPE_OUT>::run(a, b, out); - }); - }); + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + utils::apply_bitensor_elementwise_fn( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + b, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::REALHBBF16); }); return out; @@ -113,59 +65,36 @@ Tensor& mul_scalar_out( const Tensor& a, const Scalar& b, Tensor& out) { - (void)ctx; + // Common Dtype + ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b); - // Resize for dynamic shape - ET_KERNEL_CHECK_MSG( - ctx, - resize_tensor(out, a.sizes()) == Error::Ok, - InvalidArgument, - out, - "Failed to resize output tensor."); + // Check Common Dtype + ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); + // Check Dim Order ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); + // Resize ET_KERNEL_CHECK( - ctx, - executorch::runtime::tensor_is_realhbbf16_type(out), - InvalidArgument, - out); + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); + + // Compute Dtype + ScalarType compute_type = utils::get_compute_type(common_type); - ScalarType a_type = a.scalar_type(); - ScalarType b_type = utils::get_scalar_dtype(b); - ScalarType common_type = - utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false); - ScalarType out_type = out.scalar_type(); - - ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); - - if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { - common_type = ScalarType::Float; - } - - ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() { - ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() { - ET_SWITCH_REALB_TYPES( - common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALHBBF16_TYPES( - out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() { - CTYPE_B b_val; - utils::extract_scalar(b, &b_val); - CTYPE_IN b_casted = static_cast(b_val); - - apply_unary_map_fn( - [b_casted](const CTYPE_A val_a) { - CTYPE_IN a_casted = static_cast(val_a); - CTYPE_IN value = a_casted * b_casted; - return static_cast(value); - }, - a.const_data_ptr(), - out.mutable_data_ptr(), - out.numel()); - }); - }); - }); + static constexpr const char op_name[] = "mul.Scalar_out"; + + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = utils::scalar_to(b); + utils::apply_unitensor_elementwise_fn( + [val_b](const CTYPE_COMPUTE val_a) { + return val_a * val_b; + }, + ctx, + a, + utils::SupportedTensorDtypes::REALHBBF16, + out, + utils::SupportedTensorDtypes::SAME_AS_COMMON); }); return out; diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl index 42f4624d76d..9a91d710eba 100644 --- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl +++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl @@ -757,6 +757,7 @@ ATEN_OPS = ( name = "op_maximum", deps = [ "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", "//executorch/kernels/portable/cpu/util:math_util", ":scalar_utils", ], @@ -789,6 +790,7 @@ ATEN_OPS = ( name = "op_minimum", deps = [ "//executorch/kernels/portable/cpu/util:broadcast_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", "//executorch/kernels/portable/cpu/util:math_util", ":scalar_utils", ], @@ -804,7 +806,7 @@ ATEN_OPS = ( name = "op_mul", deps = [ "//executorch/kernels/portable/cpu/util:broadcast_util", - "//executorch/kernels/portable/cpu/util:functional_util", + "//executorch/kernels/portable/cpu/util:elementwise_util", ":scalar_utils", ], ),