From 0225f2fa87445a1b2326747b4f8f2abcce7108d0 Mon Sep 17 00:00:00 2001 From: Yuri Khrustalev Date: Wed, 4 Dec 2024 00:30:13 -0500 Subject: [PATCH 1/2] Remove false positive error message in the executor_runner --- kernels/portable/cpu/util/dtype_util.cpp | 11 ++++------ runtime/core/exec_aten/util/tensor_util.h | 26 +++++++++++++++++++++++ 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/kernels/portable/cpu/util/dtype_util.cpp b/kernels/portable/cpu/util/dtype_util.cpp index 299910da746..99b0301aa2d 100644 --- a/kernels/portable/cpu/util/dtype_util.cpp +++ b/kernels/portable/cpu/util/dtype_util.cpp @@ -28,17 +28,14 @@ bool check_tensor_dtype( case SupportedTensorDtypes::INTB: return executorch::runtime::tensor_is_integral_type(t, true); case SupportedTensorDtypes::BOOL_OR_BYTE: - return ( - executorch::runtime::tensor_is_type(t, ScalarType::Bool) || - executorch::runtime::tensor_is_type(t, ScalarType::Byte)); + return (executorch::runtime::tensor_is_type( + t, {ScalarType::Bool, ScalarType::Byte})); case SupportedTensorDtypes::SAME_AS_COMPUTE: return executorch::runtime::tensor_is_type(t, compute_type); case SupportedTensorDtypes::SAME_AS_COMMON: { if (compute_type == ScalarType::Float) { - return ( - executorch::runtime::tensor_is_type(t, ScalarType::Float) || - executorch::runtime::tensor_is_type(t, ScalarType::Half) || - executorch::runtime::tensor_is_type(t, ScalarType::BFloat16)); + return (executorch::runtime::tensor_is_type( + t, {ScalarType::Float, ScalarType::Half, ScalarType::BFloat16})); } else { return executorch::runtime::tensor_is_type(t, compute_type); } diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index eb57f3e099c..b1c095972c3 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -14,6 +14,8 @@ #include #include // size_t #include +#include +#include #include #include @@ -484,6 +486,30 @@ inline bool tensor_is_type( return true; } +inline bool tensor_is_type( + executorch::aten::Tensor t, + const std::vector& dtypes) { + if (std::find(dtypes.begin(), dtypes.end(), t.scalar_type()) != + dtypes.end()) { + return true; + } + + std::stringstream dtype_ss; + for (size_t i = 0; i < dtypes.size(); i++) { + if (i != 0) { + dtype_ss << ", "; + } + dtype_ss << torch::executor::toString(dtypes[i]); + } + + ET_LOG_MSG_AND_RETURN_IF_FALSE( + false, + "Expected to find one of %s types, but tensor has type %s", + dtype_ss.str().c_str(), + torch::executor::toString(t.scalar_type())); + return false; +} + inline bool tensor_is_integral_type( executorch::aten::Tensor t, bool includeBool = false) { From 3e6764e4e84c6060ba37fd80ca9d16efbd515f69 Mon Sep 17 00:00:00 2001 From: Yuri Khrustalev Date: Thu, 5 Dec 2024 00:53:07 -0500 Subject: [PATCH 2/2] simplify --- kernels/portable/cpu/util/dtype_util.cpp | 4 +-- runtime/core/exec_aten/util/tensor_util.h | 41 +++++++++++++---------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/kernels/portable/cpu/util/dtype_util.cpp b/kernels/portable/cpu/util/dtype_util.cpp index 99b0301aa2d..d240b9f83bc 100644 --- a/kernels/portable/cpu/util/dtype_util.cpp +++ b/kernels/portable/cpu/util/dtype_util.cpp @@ -29,13 +29,13 @@ bool check_tensor_dtype( return executorch::runtime::tensor_is_integral_type(t, true); case SupportedTensorDtypes::BOOL_OR_BYTE: return (executorch::runtime::tensor_is_type( - t, {ScalarType::Bool, ScalarType::Byte})); + t, ScalarType::Bool, ScalarType::Byte)); case SupportedTensorDtypes::SAME_AS_COMPUTE: return executorch::runtime::tensor_is_type(t, compute_type); case SupportedTensorDtypes::SAME_AS_COMMON: { if (compute_type == ScalarType::Float) { return (executorch::runtime::tensor_is_type( - t, {ScalarType::Float, ScalarType::Half, ScalarType::BFloat16})); + t, ScalarType::Float, ScalarType::Half, ScalarType::BFloat16)); } else { return executorch::runtime::tensor_is_type(t, compute_type); } diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index b1c095972c3..0575ab91988 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -14,8 +14,6 @@ #include #include // size_t #include -#include -#include #include #include @@ -488,26 +486,33 @@ inline bool tensor_is_type( inline bool tensor_is_type( executorch::aten::Tensor t, - const std::vector& dtypes) { - if (std::find(dtypes.begin(), dtypes.end(), t.scalar_type()) != - dtypes.end()) { - return true; - } + executorch::aten::ScalarType dtype, + executorch::aten::ScalarType dtype2) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + t.scalar_type() == dtype || t.scalar_type() == dtype2, + "Expected to find %s or %s type, but tensor has type %s", + torch::executor::toString(dtype), + torch::executor::toString(dtype2), + torch::executor::toString(t.scalar_type())); - std::stringstream dtype_ss; - for (size_t i = 0; i < dtypes.size(); i++) { - if (i != 0) { - dtype_ss << ", "; - } - dtype_ss << torch::executor::toString(dtypes[i]); - } + return true; +} +inline bool tensor_is_type( + executorch::aten::Tensor t, + executorch::aten::ScalarType dtype, + executorch::aten::ScalarType dtype2, + executorch::aten::ScalarType dtype3) { ET_LOG_MSG_AND_RETURN_IF_FALSE( - false, - "Expected to find one of %s types, but tensor has type %s", - dtype_ss.str().c_str(), + t.scalar_type() == dtype || t.scalar_type() == dtype2 || + t.scalar_type() == dtype3, + "Expected to find %s, %s, or %s type, but tensor has type %s", + torch::executor::toString(dtype), + torch::executor::toString(dtype2), + torch::executor::toString(dtype3), torch::executor::toString(t.scalar_type())); - return false; + + return true; } inline bool tensor_is_integral_type(