From 08adaf3d15e8ad9c4444f34051cd472d4eabdd32 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Tue, 28 Oct 2025 16:45:38 -0700 Subject: [PATCH] [aoti-et] Add dtype bool support --- backends/aoti/common_shims.cpp | 4 ++++ backends/aoti/common_shims.h | 1 + backends/aoti/tests/test_common_shims.cpp | 11 +++++++++++ backends/aoti/utils.h | 2 ++ backends/cuda/runtime/utils.h | 5 ++++- 5 files changed, 22 insertions(+), 1 deletion(-) diff --git a/backends/aoti/common_shims.cpp b/backends/aoti/common_shims.cpp index ac87d49d5a5..deb10478778 100644 --- a/backends/aoti/common_shims.cpp +++ b/backends/aoti/common_shims.cpp @@ -184,6 +184,10 @@ int32_t aoti_torch_dtype_int32() { return 3; // PyTorch's int32 dtype code } +int32_t aoti_torch_dtype_bool() { + return 11; // PyTorch's bool dtype code +} + int32_t aoti_torch_dtype_int64() { return 4; // PyTorch's int64 dtype code } diff --git a/backends/aoti/common_shims.h b/backends/aoti/common_shims.h index 1b0429e3aba..91bb785b684 100644 --- a/backends/aoti/common_shims.h +++ b/backends/aoti/common_shims.h @@ -63,6 +63,7 @@ int32_t aoti_torch_dtype_int8(); int32_t aoti_torch_dtype_int16(); int32_t aoti_torch_dtype_int32(); int32_t aoti_torch_dtype_int64(); +int32_t aoti_torch_dtype_bool(); // Dtype utility function needed by Metal backend size_t aoti_torch_dtype_element_size(int32_t dtype); diff --git a/backends/aoti/tests/test_common_shims.cpp b/backends/aoti/tests/test_common_shims.cpp index 980eae96122..0fd1b057f99 100644 --- a/backends/aoti/tests/test_common_shims.cpp +++ b/backends/aoti/tests/test_common_shims.cpp @@ -322,3 +322,14 @@ TEST_F(CommonShimsTest, IndependentCaches) { // Sizes and strides pointers should be different (different caches) EXPECT_NE(sizes_ptr1, strides_ptr1); } + +// Test all dtype functions return correct PyTorch dtype codes +TEST_F(CommonShimsTest, AllDtypesReturnCorrectValues) { + EXPECT_EQ(aoti_torch_dtype_float32(), 6); // PyTorch's float32 dtype code + EXPECT_EQ(aoti_torch_dtype_bfloat16(), 
15); // PyTorch's bfloat16 dtype code + EXPECT_EQ(aoti_torch_dtype_int8(), 1); // PyTorch's int8 dtype code + EXPECT_EQ(aoti_torch_dtype_int16(), 2); // PyTorch's int16 dtype code + EXPECT_EQ(aoti_torch_dtype_int32(), 3); // PyTorch's int32 dtype code + EXPECT_EQ(aoti_torch_dtype_int64(), 4); // PyTorch's int64 dtype code + EXPECT_EQ(aoti_torch_dtype_bool(), 11); // PyTorch's bool dtype code +} diff --git a/backends/aoti/utils.h b/backends/aoti/utils.h index 80abe663fa2..8f64bdbe7da 100644 --- a/backends/aoti/utils.h +++ b/backends/aoti/utils.h @@ -45,6 +45,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) { return executorch::aten::ScalarType::Long; case 6: // PyTorch's float32 dtype code return executorch::aten::ScalarType::Float; + case 11: // PyTorch's bool dtype code + return executorch::aten::ScalarType::Bool; case 15: // PyTorch's bfloat16 dtype code return executorch::aten::ScalarType::BFloat16; // Future support for additional dtypes can be added here diff --git a/backends/cuda/runtime/utils.h b/backends/cuda/runtime/utils.h index 04c1a43721a..4474f8cf57e 100644 --- a/backends/cuda/runtime/utils.h +++ b/backends/cuda/runtime/utils.h @@ -62,6 +62,7 @@ enum class SupportedDTypes : int32_t { INT32 = 3, // PyTorch's int32 dtype code INT64 = 4, // PyTorch's int64 dtype code FLOAT32 = 6, // PyTorch's float32 dtype code + BOOL = 11, // PyTorch's bool dtype code BFLOAT16 = 15, // PyTorch's bfloat16 dtype code }; @@ -84,6 +85,7 @@ inline bool is_dtype_supported_in_et_cuda(int32_t dtype) { case static_cast<int32_t>(SupportedDTypes::INT32): case static_cast<int32_t>(SupportedDTypes::INT64): case static_cast<int32_t>(SupportedDTypes::FLOAT32): + case static_cast<int32_t>(SupportedDTypes::BOOL): case static_cast<int32_t>(SupportedDTypes::BFLOAT16): return true; default: @@ -96,13 +98,14 @@ inline AOTITorchError validate_dtype(int32_t dtype) { ET_CHECK_OR_RETURN_ERROR( is_dtype_supported_in_et_cuda(dtype), InvalidArgument, - "Unsupported dtype: %d. 
Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bfloat16)", + "Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bool), %d (bfloat16)", dtype, static_cast<int32_t>(SupportedDTypes::INT8), static_cast<int32_t>(SupportedDTypes::INT16), static_cast<int32_t>(SupportedDTypes::INT32), static_cast<int32_t>(SupportedDTypes::INT64), static_cast<int32_t>(SupportedDTypes::FLOAT32), + static_cast<int32_t>(SupportedDTypes::BOOL), static_cast<int32_t>(SupportedDTypes::BFLOAT16)); return Error::Ok;