From 9c101afd4047acd9301c477803644c63e85dfd87 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Thu, 14 Nov 2024 13:46:04 -0800 Subject: [PATCH] expand dtype conversion support in aten_bridge (#6845) Summary: The dtype tables by necessity have to match exaclty so just casting to the int and then recasting to the other enum is safe Reviewed By: dulinriley Differential Revision: D65897501 --- extension/aten_util/aten_bridge.cpp | 89 +++++++++-------------------- 1 file changed, 27 insertions(+), 62 deletions(-) diff --git a/extension/aten_util/aten_bridge.cpp b/extension/aten_util/aten_bridge.cpp index 3bb7d93fa5d..f13047ab8fa 100644 --- a/extension/aten_util/aten_bridge.cpp +++ b/extension/aten_util/aten_bridge.cpp @@ -57,76 +57,41 @@ ET_CHECK_MSG( ET_CHECK_MSG( b.scalar_type() == torch_to_executorch_scalar_type(a.options().dtype()), "dtypes dont match a %hhd vs. b %hhd", - torch_to_executorch_scalar_type(a.options().dtype()), - b.scalar_type()); + static_cast(torch_to_executorch_scalar_type(a.options().dtype())), + static_cast(b.scalar_type())); } } // namespace -torch::executor::ScalarType torch_to_executorch_scalar_type( +executorch::runtime::etensor::ScalarType torch_to_executorch_scalar_type( caffe2::TypeMeta type) { - switch (c10::typeMetaToScalarType(type)) { - case c10::ScalarType::Byte: - return torch::executor::ScalarType::Byte; - case c10::ScalarType::Char: - return torch::executor::ScalarType::Char; - case c10::ScalarType::Short: - return torch::executor::ScalarType::Short; - case c10::ScalarType::Half: - return torch::executor::ScalarType::Half; - case c10::ScalarType::BFloat16: - return torch::executor::ScalarType::BFloat16; - case c10::ScalarType::Int: - return torch::executor::ScalarType::Int; - case c10::ScalarType::Float: - return torch::executor::ScalarType::Float; - case c10::ScalarType::Long: - return torch::executor::ScalarType::Long; - case c10::ScalarType::Double: - return torch::executor::ScalarType::Double; - case c10::ScalarType::Bool: - return torch::executor::ScalarType::Bool; - case c10::ScalarType::QInt8: - return torch::executor::ScalarType::QInt8; - case c10::ScalarType::QUInt8: - return torch::executor::ScalarType::QUInt8; - default: - ET_ASSERT_UNREACHABLE_MSG( - "Unrecognized dtype: %hhd", - static_cast(c10::typeMetaToScalarType(type))); - } + const auto intermediate = + static_cast::type>( + c10::typeMetaToScalarType(type)); + + ET_CHECK_MSG( + intermediate >= 0 && + intermediate <= static_cast::type>( + executorch::runtime::etensor::ScalarType::UInt64), + "ScalarType %d unsupported in Executorch", + intermediate); + return static_cast(intermediate); } c10::ScalarType executorch_to_torch_scalar_type( torch::executor::ScalarType type) { - switch (type) { - case torch::executor::ScalarType::Byte: - return c10::ScalarType::Byte; - case torch::executor::ScalarType::Char: - return c10::ScalarType::Char; - case torch::executor::ScalarType::Short: - return c10::ScalarType::Short; - case torch::executor::ScalarType::Half: - return c10::ScalarType::Half; - case torch::executor::ScalarType::BFloat16: - return c10::ScalarType::BFloat16; - case torch::executor::ScalarType::Int: - return c10::ScalarType::Int; - case torch::executor::ScalarType::Float: - return c10::ScalarType::Float; - case torch::executor::ScalarType::Long: - return c10::ScalarType::Long; - case torch::executor::ScalarType::Double: - return c10::ScalarType::Double; - case torch::executor::ScalarType::Bool: - return c10::ScalarType::Bool; - case torch::executor::ScalarType::QInt8: - return c10::ScalarType::QInt8; - case torch::executor::ScalarType::QUInt8: - return c10::ScalarType::QUInt8; - default: - ET_ASSERT_UNREACHABLE_MSG( - "Unrecognized dtype: %hhd", static_cast(type)); - } + const auto intermediate = static_cast< + std::underlying_type::type>( + type); + + ET_CHECK_MSG( + intermediate >= 0 && + intermediate <= static_cast::type>( + executorch::runtime::etensor::ScalarType::UInt64), + "ScalarType %d unsupported in Executorch", + intermediate); + return static_cast(intermediate); } /*