diff --git a/kernels/portable/cpu/util/elementwise_util.cpp b/kernels/portable/cpu/util/dtype_util.cpp
similarity index 96%
rename from kernels/portable/cpu/util/elementwise_util.cpp
rename to kernels/portable/cpu/util/dtype_util.cpp
index ef94c599b70..299910da746 100644
--- a/kernels/portable/cpu/util/elementwise_util.cpp
+++ b/kernels/portable/cpu/util/dtype_util.cpp
@@ -6,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
 
 namespace torch {
 namespace executor {
diff --git a/kernels/portable/cpu/util/dtype_util.h b/kernels/portable/cpu/util/dtype_util.h
new file mode 100644
index 00000000000..2bbd5de4577
--- /dev/null
+++ b/kernels/portable/cpu/util/dtype_util.h
@@ -0,0 +1,292 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace utils {
+namespace internal {
+
+template <typename To, typename From>
+To load_and_convert(const void* fromPtr) {
+  return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
+}
+
+template <typename To, typename From>
+void convert_and_store(From f, void* dst) {
+  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
+}
+
+template <typename CTYPE_COMMON>
+using load_to_common_fn = CTYPE_COMMON (*)(const void*);
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
+    const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_REALHBBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
+    const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_REALHBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16(
+    const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_FLOATHBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_INT_TYPES_AND(
+      Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
+    const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_TWO_TYPES(
+      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute(
+    const Tensor& t) {
+  constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
+  ET_CHECK_MSG(
+      t.scalar_type() == common_scalar_type,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(common_scalar_type),
+      op_name);
+  return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
+}
+
+template <
+    typename CTYPE_COMMON,
+    const char* op_name,
+    std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
+    const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_THREE_TYPES(
+      Float, Half, BFloat16, t.scalar_type(), unused, op_name, T, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, T>;
+      });
+  return result;
+}
+
+template <
+    typename CTYPE_COMMON,
+    const char* op_name,
+    std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
+    const Tensor& t) {
+  return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
+}
+
+template <typename CTYPE_COMMON>
+using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_REALHBBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
+    const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_REALHBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_floathbf16(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_FLOATHBF16_TYPES(
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
+    const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_INT_TYPES_AND(
+      Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_TWO_TYPES(
+      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
+  constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
+  ET_CHECK_MSG(
+      t.scalar_type() == common_scalar_type,
+      "Unhandled dtype %s for %s",
+      ::executorch::runtime::toString(common_scalar_type),
+      op_name);
+  return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
+}
+
+template <
+    typename CTYPE_COMMON,
+    const char* op_name,
+    std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_THREE_TYPES(
+      Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() {
+        result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+
+template <
+    typename CTYPE_COMMON,
+    const char* op_name,
+    std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
+  return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
+      t);
+}
+
+} // namespace internal
+
+enum class SupportedTensorDtypes {
+  REALHBBF16,
+  REALHBF16,
+  FLOATHBF16,
+  INTB,
+  BOOL_OR_BYTE,
+  SAME_AS_COMPUTE,
+  SAME_AS_COMMON,
+};
+
+namespace internal {
+
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
+    const Tensor& t,
+    SupportedTensorDtypes dtypes) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::REALHBF16:
+      return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::FLOATHBF16:
+      // NOTE(review): was get_load_to_common_fn_realhbf16, which left the
+      // _floathbf16 loader dead and, unlike the store path below, would also
+      // accept integral dtypes that FLOATHBF16 is meant to exclude.
+      return get_load_to_common_fn_floathbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::INTB:
+      return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::SAME_AS_COMPUTE:
+      return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::SAME_AS_COMMON:
+      return get_load_to_common_fn_same_as_common<CTYPE_COMMON, op_name>(t);
+  }
+  ET_CHECK(false);
+  return nullptr;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
+    const Tensor& t,
+    SupportedTensorDtypes dtypes) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::REALHBF16:
+      return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::FLOATHBF16:
+      return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::INTB:
+      return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
+          t);
+    case SupportedTensorDtypes::SAME_AS_COMPUTE:
+      return get_store_common_to_tensor_fn_same_as_compute<
+          CTYPE_COMMON,
+          op_name>(t);
+    case SupportedTensorDtypes::SAME_AS_COMMON: {
+      return get_store_common_to_tensor_fn_same_as_common<
+          CTYPE_COMMON,
+          op_name>(t);
+    }
+  }
+  ET_CHECK(false);
+  return nullptr;
+}
+
+bool check_tensor_dtype(
+    const Tensor t,
+    SupportedTensorDtypes dtypes,
+    const ScalarType compute_type);
+
+} // namespace internal
+} // namespace utils
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/util/elementwise_util.h b/kernels/portable/cpu/util/elementwise_util.h
index 3e8c3d2c6e2..3d06c7a3283 100644
--- a/kernels/portable/cpu/util/elementwise_util.h
+++ b/kernels/portable/cpu/util/elementwise_util.h
@@ -9,6 +9,7 @@ #pragma once
 
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
 namespace executor {
@@ -43,279 +44,6 @@ inline int64_t scalar_to(const Scalar& s) {
       : s.to<int64_t>();
 }
 
-namespace internal {
-
-template <typename To, typename From>
-To load_and_convert(const void* fromPtr) {
-  return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
-}
-
-template <typename To, typename From>
-void convert_and_store(From f, void* dst) {
-  *reinterpret_cast<To*>(dst) = static_cast<To>(f);
-}
-
-template <typename CTYPE_COMMON>
-using load_to_common_fn = CTYPE_COMMON (*)(const void*);
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
-    const Tensor& t) {
-  CTYPE_COMMON (*result)(const void*) = nullptr;
-  ET_SWITCH_REALHBBF16_TYPES(
-      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
-    const Tensor& t) {
-  CTYPE_COMMON (*result)(const void*) = nullptr;
-  ET_SWITCH_REALHBF16_TYPES(
-      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16(
-    const Tensor& t) {
-  CTYPE_COMMON (*result)(const void*) = nullptr;
-  ET_SWITCH_FLOATHBF16_TYPES(
-      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
-  CTYPE_COMMON (*result)(const void*) = nullptr;
-  ET_SWITCH_INT_TYPES_AND(
-      Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
-    const Tensor& t) {
-  CTYPE_COMMON (*result)(const void*) = nullptr;
-  ET_SWITCH_TWO_TYPES(
-      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute(
-    const Tensor& t) {
-  constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-  ET_CHECK_MSG(
-      t.scalar_type() == common_scalar_type,
-      "Unhandled dtype %s for %s",
-      ::executorch::runtime::toString(common_scalar_type),
-      op_name);
-  return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
-}
-
-template <
-    typename CTYPE_COMMON,
-    const char* op_name,
-    std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
-    const Tensor& t) {
-  CTYPE_COMMON (*result)(const void*) = nullptr;
-  ET_SWITCH_THREE_TYPES(
-      Float, Half, BFloat16, t.scalar_type(), unused, op_name, T, [&]() {
-        result = internal::load_and_convert<CTYPE_COMMON, T>;
-      });
-  return result;
-}
-
-template <
-    typename CTYPE_COMMON,
-    const char* op_name,
-    std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
-    const Tensor& t) {
-  return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
-}
-
-template <typename CTYPE_COMMON>
-using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_REALHBBF16_TYPES(
-      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
-    const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_REALHBF16_TYPES(
-      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_floathbf16(const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_FLOATHBF16_TYPES(
-      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
-    const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_INT_TYPES_AND(
-      Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_TWO_TYPES(
-      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
-        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
-  constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
-  ET_CHECK_MSG(
-      t.scalar_type() == common_scalar_type,
-      "Unhandled dtype %s for %s",
-      ::executorch::runtime::toString(common_scalar_type),
-      op_name);
-  return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
-}
-
-template <
-    typename CTYPE_COMMON,
-    const char* op_name,
-    std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
-  void (*result)(CTYPE_COMMON, void*) = nullptr;
-  ET_SWITCH_THREE_TYPES(
-      Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() {
-        result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
-      });
-  return result;
-}
-
-template <
-    typename CTYPE_COMMON,
-    const char* op_name,
-    std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
-store_common_to_tensor_fn<CTYPE_COMMON>
-get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
-  return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
-      t);
-}
-
-} // namespace internal
-
-enum class SupportedTensorDtypes {
-  REALHBBF16,
-  REALHBF16,
-  FLOATHBF16,
-  INTB,
-  BOOL_OR_BYTE,
-  SAME_AS_COMPUTE,
-  SAME_AS_COMMON,
-};
-
-namespace internal {
-
-template <typename CTYPE_COMMON, const char* op_name>
-load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
-    const Tensor& t,
-    SupportedTensorDtypes dtypes) {
-  switch (dtypes) {
-    case SupportedTensorDtypes::REALHBBF16:
-      return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::REALHBF16:
-      return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::FLOATHBF16:
-      return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::INTB:
-      return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::BOOL_OR_BYTE:
-      return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::SAME_AS_COMPUTE:
-      return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::SAME_AS_COMMON:
-      return get_load_to_common_fn_same_as_common<CTYPE_COMMON, op_name>(t);
-  }
-  ET_CHECK(false);
-  return nullptr;
-}
-
-template <typename CTYPE_COMMON, const char* op_name>
-store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
-    const Tensor& t,
-    SupportedTensorDtypes dtypes) {
-  switch (dtypes) {
-    case SupportedTensorDtypes::REALHBBF16:
-      return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::REALHBF16:
-      return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::FLOATHBF16:
-      return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::INTB:
-      return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
-    case SupportedTensorDtypes::BOOL_OR_BYTE:
-      return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
-          t);
-    case SupportedTensorDtypes::SAME_AS_COMPUTE:
-      return get_store_common_to_tensor_fn_same_as_compute<
-          CTYPE_COMMON,
-          op_name>(t);
-    case SupportedTensorDtypes::SAME_AS_COMMON: {
-      return get_store_common_to_tensor_fn_same_as_common<
-          CTYPE_COMMON,
-          op_name>(t);
-    }
-  }
-  ET_CHECK(false);
-  return nullptr;
-}
-
-bool check_tensor_dtype(
-    const Tensor t,
-    SupportedTensorDtypes dtypes,
-    const ScalarType compute_type);
-
-} // namespace internal
-
 template <typename CTYPE_COMMON, const char* op_name, typename Op>
 inline void apply_unitensor_elementwise_fn(
     const Op& compute_fun,
diff --git a/kernels/portable/cpu/util/targets.bzl b/kernels/portable/cpu/util/targets.bzl
index b182ce192b5..eb4873d1d17 100644
--- a/kernels/portable/cpu/util/targets.bzl
+++ b/kernels/portable/cpu/util/targets.bzl
@@ -78,15 +78,28 @@ def define_common_targets():
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/...", "@EXECUTORCH_CLIENTS"],
     )
 
+    runtime.cxx_library(
+        name = "dtype_util",
+        srcs = ["dtype_util.cpp"],
+        exported_headers = [
+            "dtype_util.h",
+        ],
+        compiler_flags = ["-Wno-missing-prototypes"],
+        deps = [
+            "//executorch/runtime/kernel:kernel_includes",
+        ],
+        visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
+    )
+
     runtime.cxx_library(
         name = "elementwise_util",
-        srcs = ["elementwise_util.cpp"],
         exported_headers = [
            "elementwise_util.h",
         ],
         compiler_flags = ["-Wno-missing-prototypes"],
         deps = [
             ":broadcast_util",
+            ":dtype_util",
             "//executorch/kernels/portable/cpu:scalar_utils",
             "//executorch/runtime/kernel:kernel_includes",
         ],
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index b7a16235548..92b940e6305 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -214,6 +214,7 @@ ATEN_OPS = (
         name = "op_add",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
            "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:kernel_ops_util",
             ":scalar_utils",
@@ -304,6 +305,7 @@ ATEN_OPS = (
         name = "op_atan2",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -325,6 +327,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:bitwise_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -342,6 +345,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:bitwise_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -351,6 +355,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:bitwise_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -385,6 +390,7 @@ ATEN_OPS = (
         deps = [
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:math_util",
         ],
@@ -457,6 +463,7 @@ ATEN_OPS = (
         name = "op_div",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:math_util",
             ":scalar_utils",
@@ -474,6 +481,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -533,6 +541,7 @@ ATEN_OPS = (
         name = "op_floor_divide",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:math_util",
         ],
@@ -542,6 +551,7 @@ ATEN_OPS = (
         deps = [
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -573,6 +583,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -598,6 +609,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -647,6 +659,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -702,6 +715,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:logical_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -718,6 +732,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:logical_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -727,6 +742,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:logical_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -742,6 +758,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -766,6 +783,7 @@ ATEN_OPS = (
         name = "op_maximum",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:math_util",
             ":scalar_utils",
@@ -799,6 +817,7 @@ ATEN_OPS = (
         name = "op_minimum",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:math_util",
             ":scalar_utils",
@@ -815,6 +834,7 @@ ATEN_OPS = (
         name = "op_mul",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
             ":scalar_utils",
         ],
@@ -851,6 +871,7 @@ ATEN_OPS = (
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -902,6 +923,7 @@ ATEN_OPS = (
         deps = [
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -948,6 +970,7 @@ ATEN_OPS = (
         deps = [
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:math_util",
         ],
@@ -996,6 +1019,7 @@ ATEN_OPS = (
         deps = [
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -1114,6 +1138,7 @@ ATEN_OPS = (
         deps = [
             ":scalar_utils",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
@@ -1197,6 +1222,7 @@ ATEN_OPS = (
         name = "op_where",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
             "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),