@@ -57,76 +57,41 @@ ET_CHECK_MSG(
5757 ET_CHECK_MSG (
5858 b.scalar_type () == torch_to_executorch_scalar_type (a.options ().dtype ()),
5959 " dtypes dont match a %hhd vs. b %hhd" ,
60- torch_to_executorch_scalar_type (a.options ().dtype ()),
61- b.scalar_type ());
60+ static_cast < int8_t >( torch_to_executorch_scalar_type (a.options ().dtype () )),
61+ static_cast < int8_t >( b.scalar_type () ));
6262}
6363} // namespace
6464
65- torch::executor ::ScalarType torch_to_executorch_scalar_type (
65+ executorch::runtime::etensor ::ScalarType torch_to_executorch_scalar_type (
6666 caffe2::TypeMeta type) {
67- switch (c10::typeMetaToScalarType (type)) {
68- case c10::ScalarType::Byte:
69- return torch::executor::ScalarType::Byte;
70- case c10::ScalarType::Char:
71- return torch::executor::ScalarType::Char;
72- case c10::ScalarType::Short:
73- return torch::executor::ScalarType::Short;
74- case c10::ScalarType::Half:
75- return torch::executor::ScalarType::Half;
76- case c10::ScalarType::BFloat16:
77- return torch::executor::ScalarType::BFloat16;
78- case c10::ScalarType::Int:
79- return torch::executor::ScalarType::Int;
80- case c10::ScalarType::Float:
81- return torch::executor::ScalarType::Float;
82- case c10::ScalarType::Long:
83- return torch::executor::ScalarType::Long;
84- case c10::ScalarType::Double:
85- return torch::executor::ScalarType::Double;
86- case c10::ScalarType::Bool:
87- return torch::executor::ScalarType::Bool;
88- case c10::ScalarType::QInt8:
89- return torch::executor::ScalarType::QInt8;
90- case c10::ScalarType::QUInt8:
91- return torch::executor::ScalarType::QUInt8;
92- default :
93- ET_ASSERT_UNREACHABLE_MSG (
94- " Unrecognized dtype: %hhd" ,
95- static_cast <int8_t >(c10::typeMetaToScalarType (type)));
96- }
67+ const auto intermediate =
68+ static_cast <std::underlying_type<c10::ScalarType>::type>(
69+ c10::typeMetaToScalarType (type));
70+
71+ ET_CHECK_MSG (
72+ intermediate >= 0 &&
73+ intermediate <= static_cast <std::underlying_type<
74+ executorch::runtime::etensor::ScalarType>::type>(
75+ executorch::runtime::etensor::ScalarType::UInt64),
76+ " ScalarType %d unsupported in Executorch" ,
77+ intermediate);
78+ return static_cast <executorch::runtime::etensor::ScalarType>(intermediate);
9779}
9880
9981c10::ScalarType executorch_to_torch_scalar_type (
10082 torch::executor::ScalarType type) {
101- switch (type) {
102- case torch::executor::ScalarType::Byte:
103- return c10::ScalarType::Byte;
104- case torch::executor::ScalarType::Char:
105- return c10::ScalarType::Char;
106- case torch::executor::ScalarType::Short:
107- return c10::ScalarType::Short;
108- case torch::executor::ScalarType::Half:
109- return c10::ScalarType::Half;
110- case torch::executor::ScalarType::BFloat16:
111- return c10::ScalarType::BFloat16;
112- case torch::executor::ScalarType::Int:
113- return c10::ScalarType::Int;
114- case torch::executor::ScalarType::Float:
115- return c10::ScalarType::Float;
116- case torch::executor::ScalarType::Long:
117- return c10::ScalarType::Long;
118- case torch::executor::ScalarType::Double:
119- return c10::ScalarType::Double;
120- case torch::executor::ScalarType::Bool:
121- return c10::ScalarType::Bool;
122- case torch::executor::ScalarType::QInt8:
123- return c10::ScalarType::QInt8;
124- case torch::executor::ScalarType::QUInt8:
125- return c10::ScalarType::QUInt8;
126- default :
127- ET_ASSERT_UNREACHABLE_MSG (
128- " Unrecognized dtype: %hhd" , static_cast <int8_t >(type));
129- }
83+ const auto intermediate = static_cast <
84+ std::underlying_type<executorch::runtime::etensor::ScalarType>::type>(
85+ type);
86+
87+ ET_CHECK_MSG (
88+ intermediate >= 0 &&
89+ intermediate <= static_cast <std::underlying_type<
90+ executorch::runtime::etensor::ScalarType>::type>(
91+ executorch::runtime::etensor::ScalarType::UInt64),
92+ " ScalarType %d unsupported in Executorch" ,
93+ intermediate);
94+ return static_cast <c10::ScalarType>(intermediate);
13095}
13196
13297/*
0 commit comments