diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp
index c9c6eda9324..adb9d4ea723 100644
--- a/kernels/portable/cpu/op_add.cpp
+++ b/kernels/portable/cpu/op_add.cpp
@@ -22,14 +22,6 @@ Tensor& add_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      (executorch::runtime::tensor_is_realhbbf16_type(a) &&
-       executorch::runtime::tensor_is_realhbbf16_type(b) &&
-       executorch::runtime::tensor_is_realhbbf16_type(out)),
-      InvalidArgument,
-      out);
-
   // Common Dtype
   ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
 
@@ -64,6 +56,7 @@ Tensor& add_out(
         [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
           return val_a + val_alpha * val_b;
         },
+        ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
@@ -81,13 +74,6 @@ Tensor& add_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      (executorch::runtime::tensor_is_realhbbf16_type(a) &&
-       executorch::runtime::tensor_is_realhbbf16_type(out)),
-      InvalidArgument,
-      out);
-
   // Common Dtype
   ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
 
@@ -120,6 +106,7 @@ Tensor& add_scalar_out(
           CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
           return val_a + val_alpha * val_b;
         },
+        ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         out,
diff --git a/kernels/portable/cpu/op_clamp.cpp b/kernels/portable/cpu/op_clamp.cpp
index 35218cbb599..3f282a4473a 100644
--- a/kernels/portable/cpu/op_clamp.cpp
+++ b/kernels/portable/cpu/op_clamp.cpp
@@ -73,13 +73,6 @@ Tensor& clamp_out(
     const exec_aten::optional<Scalar>& min_opt,
     const exec_aten::optional<Scalar>& max_opt,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      (executorch::runtime::tensor_is_realhbbf16_type(in) &&
-       executorch::runtime::tensor_is_realhbbf16_type(out)),
-      InvalidArgument,
-      out);
-
   bool has_min = min_opt.has_value();
   bool has_max = max_opt.has_value();
 
@@ -154,6 +147,7 @@ Tensor& clamp_out(
           }
           return val_out;
         },
+        ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
         out,
@@ -182,15 +176,6 @@ Tensor& clamp_tensor_out(
   const Tensor& min = has_min ? min_opt.value() : in;
   const Tensor& max = has_max ? max_opt.value() : in;
 
-  ET_KERNEL_CHECK(
-      ctx,
-      (executorch::runtime::tensor_is_realhbbf16_type(in) &&
-       executorch::runtime::tensor_is_realhbbf16_type(min) &&
-       executorch::runtime::tensor_is_realhbbf16_type(max) &&
-       executorch::runtime::tensor_is_realhbbf16_type(out)),
-      InvalidArgument,
-      out);
-
   // Common Dtype
   ScalarType common_type = in.scalar_type();
   if (has_min) {
@@ -239,6 +224,7 @@ Tensor& clamp_tensor_out(
           }
           return val_out;
         },
+        ctx,
         in,
         utils::SupportedTensorDtypes::REALHBBF16,
         min,
diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp
index 983cbc8cbb9..b455c45c2d1 100644
--- a/kernels/portable/cpu/op_where.cpp
+++ b/kernels/portable/cpu/op_where.cpp
@@ -19,16 +19,6 @@ Tensor& where_out(
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      ((cond.scalar_type() == ScalarType::Bool ||
-        cond.scalar_type() == ScalarType::Byte) &&
-       executorch::runtime::tensor_is_realhbbf16_type(a) &&
-       executorch::runtime::tensor_is_realhbbf16_type(b) &&
-       executorch::runtime::tensor_is_realhbbf16_type(out)),
-      InvalidArgument,
-      out);
-
   // Common Dtype
   ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
 
@@ -57,6 +47,7 @@ Tensor& where_out(
         [](const CTYPE_COMPUTE val_a,
            const CTYPE_COMPUTE val_b,
            const CTYPE_COMPUTE val_c) { return val_c ? val_a : val_b; },
+        ctx,
         a,
         utils::SupportedTensorDtypes::REALHBBF16,
         b,
diff --git a/kernels/portable/cpu/util/elementwise_util.cpp b/kernels/portable/cpu/util/elementwise_util.cpp
new file mode 100644
index 00000000000..bafb7e464c0
--- /dev/null
+++ b/kernels/portable/cpu/util/elementwise_util.cpp
@@ -0,0 +1,49 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace utils {
+namespace internal {
+
+bool check_tensor_dtype(
+    const Tensor t,
+    SupportedTensorDtypes dtypes,
+    const ScalarType compute_type) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return executorch::runtime::tensor_is_realhbbf16_type(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return (
+          executorch::runtime::tensor_is_type(t, ScalarType::Bool) ||
+          executorch::runtime::tensor_is_type(t, ScalarType::Byte));
+    case SupportedTensorDtypes::SAME_AS_COMPUTE:
+      return executorch::runtime::tensor_is_type(t, compute_type);
+    case SupportedTensorDtypes::SAME_AS_COMMON: {
+      if (compute_type == ScalarType::Float) {
+        return (
+            executorch::runtime::tensor_is_type(t, ScalarType::Float) ||
+            executorch::runtime::tensor_is_type(t, ScalarType::Half) ||
+            executorch::runtime::tensor_is_type(t, ScalarType::BFloat16));
+      } else {
+        return executorch::runtime::tensor_is_type(t, compute_type);
+      }
+    }
+  }
+  ET_CHECK(false);
+  return false;
+}
+
+} // namespace internal
+} // namespace utils
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 52ad6fca116..28b3e964dbf 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -229,15 +229,29 @@ store_common_to_tensor_fn get_store_common_to_tensor_fn(
   return nullptr;
 }
 
+bool check_tensor_dtype(
+    const Tensor t,
+    SupportedTensorDtypes dtypes,
+    const ScalarType compute_type);
+
 } // namespace internal
 
 template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
 inline void apply_unitensor_elementwise_fn(
     const Op& compute_fun,
+    KernelRuntimeContext& ctx,
     const Tensor& a,
     SupportedTensorDtypes a_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
+  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
+
+  ET_KERNEL_CHECK(
+      ctx,
+      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
+       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
+      InvalidArgument, );
+
   const auto load_a_to_common =
       internal::get_load_to_common_fn<CTYPE_COMPUTE, op_name>(a, a_dtypes);
   const auto store_common_to_out =
@@ -263,12 +277,22 @@ inline void apply_unitensor_elementwise_fn(
 template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
 inline void apply_bitensor_elementwise_fn(
     const Op& compute_fun,
+    KernelRuntimeContext& ctx,
     const Tensor& a,
     SupportedTensorDtypes a_dtypes,
     const Tensor& b,
     SupportedTensorDtypes b_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
+  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
+
+  ET_KERNEL_CHECK(
+      ctx,
+      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
+       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
+       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
+      InvalidArgument, );
+
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
@@ -312,9 +336,9 @@ inline void apply_bitensor_elementwise_fn(
 }
 
 /**
- * Useful for tri-tensor elementwise operators. For each element of the inputs,
- * perform a computation and write to the corresponding element of the output.
- * Tensor broadcasting is applied wherever it is required.
+ * Useful for tri-tensor elementwise operators. For each element of the
+ * inputs, perform a computation and write to the corresponding element of the
+ * output. Tensor broadcasting is applied wherever it is required.
  *
  * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
 * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
@@ -334,6 +358,7 @@
 template <typename CTYPE_COMPUTE, const char* op_name, typename Op>
 inline void apply_tritensor_elementwise_fn(
     const Op& compute_fun,
+    KernelRuntimeContext& ctx,
     const Tensor& a,
     SupportedTensorDtypes a_dtypes,
     const Tensor& b,
@@ -342,6 +367,16 @@
     SupportedTensorDtypes c_dtypes,
     const Tensor& out,
     SupportedTensorDtypes out_dtypes) {
+  constexpr auto compute_type = CppTypeToScalarType<CTYPE_COMPUTE>::value;
+
+  ET_KERNEL_CHECK(
+      ctx,
+      (internal::check_tensor_dtype(a, a_dtypes, compute_type) &&
+       internal::check_tensor_dtype(b, b_dtypes, compute_type) &&
+       internal::check_tensor_dtype(c, c_dtypes, compute_type) &&
+       internal::check_tensor_dtype(out, out_dtypes, compute_type)),
+      InvalidArgument, );
+
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index 2285206728f..b182ce192b5 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -80,6 +80,7 @@ def define_common_targets():
 
     runtime.cxx_library(
         name = "elementwise_util",
+        srcs = ["elementwise_util.cpp"],
         exported_headers = [
             "elementwise_util.h",
         ],
diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h
index b303feafd46..28395197bce 100644
--- a/runtime/core/exec_aten/util/tensor_util.h
+++ b/runtime/core/exec_aten/util/tensor_util.h
@@ -469,6 +469,16 @@ inline bool tensor_is_bool_type(exec_aten::Tensor t) {
   return true;
 }
 
+inline bool tensor_is_type(exec_aten::Tensor t, exec_aten::ScalarType dtype) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      t.scalar_type() == dtype,
+      "Expected to find %s type, but tensor has type %s",
+      torch::executor::toString(dtype),
+      torch::executor::toString(t.scalar_type()));
+
+  return true;
+}
+
 inline bool tensor_is_integral_type(
     exec_aten::Tensor t, bool includeBool = false) {