@@ -57,73 +57,29 @@ ET_CHECK_MSG(
5757 ET_CHECK_MSG (
5858 b.scalar_type () == torch_to_executorch_scalar_type (a.options ().dtype ()),
5959 " dtypes dont match a %hhd vs. b %hhd" ,
60- torch_to_executorch_scalar_type (a.options ().dtype ()),
61- b.scalar_type ());
60+ static_cast < int8_t >( torch_to_executorch_scalar_type (a.options ().dtype () )),
61+ static_cast < int8_t >( b.scalar_type () ));
6262}
6363} // namespace
6464
65- torch::executor ::ScalarType torch_to_executorch_scalar_type (
65+ executorch::runtime::etensor ::ScalarType torch_to_executorch_scalar_type (
6666 caffe2::TypeMeta type) {
67- switch (c10::typeMetaToScalarType (type)) {
68- case c10::ScalarType::Byte:
69- return torch::executor::ScalarType::Byte;
70- case c10::ScalarType::Char:
71- return torch::executor::ScalarType::Char;
72- case c10::ScalarType::Short:
73- return torch::executor::ScalarType::Short;
74- case c10::ScalarType::Half:
75- return torch::executor::ScalarType::Half;
76- case c10::ScalarType::BFloat16:
77- return torch::executor::ScalarType::BFloat16;
78- case c10::ScalarType::Int:
79- return torch::executor::ScalarType::Int;
80- case c10::ScalarType::Float:
81- return torch::executor::ScalarType::Float;
82- case c10::ScalarType::Long:
83- return torch::executor::ScalarType::Long;
84- case c10::ScalarType::Double:
85- return torch::executor::ScalarType::Double;
86- case c10::ScalarType::Bool:
87- return torch::executor::ScalarType::Bool;
88- case c10::ScalarType::QInt8:
89- return torch::executor::ScalarType::QInt8;
90- case c10::ScalarType::QUInt8:
91- return torch::executor::ScalarType::QUInt8;
92- default :
93- ET_ASSERT_UNREACHABLE ();
94- }
67+ int8_t intermediate = static_cast <int8_t >(c10::typeMetaToScalarType (type));
68+ // 29 is the latest scalartype entry ET added support for in scalar_type.h
69+ ET_CHECK_MSG (
70+ intermediate >= 0 && intermediate <= 29 ,
71+ " ScalarType %d unsupported in Executorch" , intermediate);
72+ return static_cast <executorch::runtime::etensor::ScalarType>(intermediate);
9573}
9674
9775c10::ScalarType executorch_to_torch_scalar_type (
9876 torch::executor::ScalarType type) {
99- switch (type) {
100- case torch::executor::ScalarType::Byte:
101- return c10::ScalarType::Byte;
102- case torch::executor::ScalarType::Char:
103- return c10::ScalarType::Char;
104- case torch::executor::ScalarType::Short:
105- return c10::ScalarType::Short;
106- case torch::executor::ScalarType::Half:
107- return c10::ScalarType::Half;
108- case torch::executor::ScalarType::BFloat16:
109- return c10::ScalarType::BFloat16;
110- case torch::executor::ScalarType::Int:
111- return c10::ScalarType::Int;
112- case torch::executor::ScalarType::Float:
113- return c10::ScalarType::Float;
114- case torch::executor::ScalarType::Long:
115- return c10::ScalarType::Long;
116- case torch::executor::ScalarType::Double:
117- return c10::ScalarType::Double;
118- case torch::executor::ScalarType::Bool:
119- return c10::ScalarType::Bool;
120- case torch::executor::ScalarType::QInt8:
121- return c10::ScalarType::QInt8;
122- case torch::executor::ScalarType::QUInt8:
123- return c10::ScalarType::QUInt8;
124- default :
125- ET_ASSERT_UNREACHABLE ();
126- }
77+ int8_t intermediate = static_cast <int8_t >(type);
78+ // 29 is the latest scalartype entry ET added support for in scalar_type.h
79+ ET_CHECK_MSG (
80+ intermediate >= 0 && intermediate <= 29 ,
81+ " ScalarType %d unsupported in Executorch" , intermediate);
82+ return static_cast <c10::ScalarType>(intermediate);
12783}
12884
12985/*
0 commit comments