File tree Expand file tree Collapse file tree 2 files changed +30
-7
lines changed
kernels/portable/cpu/util
runtime/core/exec_aten/util Expand file tree Collapse file tree 2 files changed +30
-7
lines changed Original file line number Diff line number Diff line change @@ -28,17 +28,14 @@ bool check_tensor_dtype(
2828 case SupportedTensorDtypes::INTB:
2929 return executorch::runtime::tensor_is_integral_type (t, true );
3030 case SupportedTensorDtypes::BOOL_OR_BYTE:
31- return (
32- executorch::runtime::tensor_is_type (t, ScalarType::Bool) ||
33- executorch::runtime::tensor_is_type (t, ScalarType::Byte));
31+ return (executorch::runtime::tensor_is_type (
32+ t, {ScalarType::Bool, ScalarType::Byte}));
3433 case SupportedTensorDtypes::SAME_AS_COMPUTE:
3534 return executorch::runtime::tensor_is_type (t, compute_type);
3635 case SupportedTensorDtypes::SAME_AS_COMMON: {
3736 if (compute_type == ScalarType::Float) {
38- return (
39- executorch::runtime::tensor_is_type (t, ScalarType::Float) ||
40- executorch::runtime::tensor_is_type (t, ScalarType::Half) ||
41- executorch::runtime::tensor_is_type (t, ScalarType::BFloat16));
37+ return (executorch::runtime::tensor_is_type (
38+ t, {ScalarType::Float, ScalarType::Half, ScalarType::BFloat16}));
4239 } else {
4340 return executorch::runtime::tensor_is_type (t, compute_type);
4441 }
Original file line number Diff line number Diff line change 1414#include < cmath>
1515#include < cstddef> // size_t
1616#include < limits>
17+ #include < sstream>
18+ #include < vector>
1719
1820#include < executorch/runtime/core/array_ref.h>
1921#include < executorch/runtime/core/error.h>
@@ -484,6 +486,30 @@ inline bool tensor_is_type(
484486 return true ;
485487}
486488
489+ inline bool tensor_is_type (
490+ executorch::aten::Tensor t,
491+ const std::vector<executorch::aten::ScalarType>& dtypes) {
492+ if (std::find (dtypes.begin (), dtypes.end (), t.scalar_type ()) !=
493+ dtypes.end ()) {
494+ return true ;
495+ }
496+
497+ std::stringstream dtype_ss;
498+ for (size_t i = 0 ; i < dtypes.size (); i++) {
499+ if (i != 0 ) {
500+ dtype_ss << " , " ;
501+ }
502+ dtype_ss << torch::executor::toString (dtypes[i]);
503+ }
504+
505+ ET_LOG_MSG_AND_RETURN_IF_FALSE (
506+ false ,
507+ " Expected to find one of %s types, but tensor has type %s" ,
508+ dtype_ss.str ().c_str (),
509+ torch::executor::toString (t.scalar_type ()));
510+ return false ;
511+ }
512+
487513inline bool tensor_is_integral_type (
488514 executorch::aten::Tensor t,
489515 bool includeBool = false ) {
You can’t perform that action at this time.
0 commit comments