diff --git a/extension/aten_util/aten_bridge.cpp b/extension/aten_util/aten_bridge.cpp index 3bb7d93fa5d..f13047ab8fa 100644 --- a/extension/aten_util/aten_bridge.cpp +++ b/extension/aten_util/aten_bridge.cpp @@ -57,76 +57,41 @@ ET_CHECK_MSG( ET_CHECK_MSG( b.scalar_type() == torch_to_executorch_scalar_type(a.options().dtype()), "dtypes dont match a %hhd vs. b %hhd", - torch_to_executorch_scalar_type(a.options().dtype()), - b.scalar_type()); + static_cast(torch_to_executorch_scalar_type(a.options().dtype())), + static_cast(b.scalar_type())); } } // namespace -torch::executor::ScalarType torch_to_executorch_scalar_type( +executorch::runtime::etensor::ScalarType torch_to_executorch_scalar_type( caffe2::TypeMeta type) { - switch (c10::typeMetaToScalarType(type)) { - case c10::ScalarType::Byte: - return torch::executor::ScalarType::Byte; - case c10::ScalarType::Char: - return torch::executor::ScalarType::Char; - case c10::ScalarType::Short: - return torch::executor::ScalarType::Short; - case c10::ScalarType::Half: - return torch::executor::ScalarType::Half; - case c10::ScalarType::BFloat16: - return torch::executor::ScalarType::BFloat16; - case c10::ScalarType::Int: - return torch::executor::ScalarType::Int; - case c10::ScalarType::Float: - return torch::executor::ScalarType::Float; - case c10::ScalarType::Long: - return torch::executor::ScalarType::Long; - case c10::ScalarType::Double: - return torch::executor::ScalarType::Double; - case c10::ScalarType::Bool: - return torch::executor::ScalarType::Bool; - case c10::ScalarType::QInt8: - return torch::executor::ScalarType::QInt8; - case c10::ScalarType::QUInt8: - return torch::executor::ScalarType::QUInt8; - default: - ET_ASSERT_UNREACHABLE_MSG( - "Unrecognized dtype: %hhd", - static_cast(c10::typeMetaToScalarType(type))); - } + const auto intermediate = + static_cast::type>( + c10::typeMetaToScalarType(type)); + + ET_CHECK_MSG( + intermediate >= 0 && + intermediate <= static_cast::type>( + executorch::runtime::etensor::ScalarType::UInt64), + "ScalarType %d unsupported in Executorch", + intermediate); + return static_cast(intermediate); } c10::ScalarType executorch_to_torch_scalar_type( torch::executor::ScalarType type) { - switch (type) { - case torch::executor::ScalarType::Byte: - return c10::ScalarType::Byte; - case torch::executor::ScalarType::Char: - return c10::ScalarType::Char; - case torch::executor::ScalarType::Short: - return c10::ScalarType::Short; - case torch::executor::ScalarType::Half: - return c10::ScalarType::Half; - case torch::executor::ScalarType::BFloat16: - return c10::ScalarType::BFloat16; - case torch::executor::ScalarType::Int: - return c10::ScalarType::Int; - case torch::executor::ScalarType::Float: - return c10::ScalarType::Float; - case torch::executor::ScalarType::Long: - return c10::ScalarType::Long; - case torch::executor::ScalarType::Double: - return c10::ScalarType::Double; - case torch::executor::ScalarType::Bool: - return c10::ScalarType::Bool; - case torch::executor::ScalarType::QInt8: - return c10::ScalarType::QInt8; - case torch::executor::ScalarType::QUInt8: - return c10::ScalarType::QUInt8; - default: - ET_ASSERT_UNREACHABLE_MSG( - "Unrecognized dtype: %hhd", static_cast(type)); - } + const auto intermediate = static_cast< + std::underlying_type::type>( + type); + + ET_CHECK_MSG( + intermediate >= 0 && + intermediate <= static_cast::type>( + executorch::runtime::etensor::ScalarType::UInt64), + "ScalarType %d unsupported in Executorch", + intermediate); + return static_cast(intermediate); } /*