Skip to content

Commit 08adaf3

Browse files
committed
[aoti-et] Add dtype bool support
1 parent 30d7cae commit 08adaf3

File tree

5 files changed

+22
-1
lines changed

5 files changed

+22
-1
lines changed

backends/aoti/common_shims.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ int32_t aoti_torch_dtype_int32() {
184184
return 3; // PyTorch's int32 dtype code
185185
}
186186

187+
int32_t aoti_torch_dtype_bool() {
188+
return 11; // PyTorch's bool dtype code
189+
}
190+
187191
int32_t aoti_torch_dtype_int64() {
188192
return 4; // PyTorch's int64 dtype code
189193
}

backends/aoti/common_shims.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ int32_t aoti_torch_dtype_int8();
6363
int32_t aoti_torch_dtype_int16();
6464
int32_t aoti_torch_dtype_int32();
6565
int32_t aoti_torch_dtype_int64();
66+
int32_t aoti_torch_dtype_bool();
6667

6768
// Dtype utility function needed by Metal backend
6869
size_t aoti_torch_dtype_element_size(int32_t dtype);

backends/aoti/tests/test_common_shims.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,14 @@ TEST_F(CommonShimsTest, IndependentCaches) {
322322
// Sizes and strides pointers should be different (different caches)
323323
EXPECT_NE(sizes_ptr1, strides_ptr1);
324324
}
325+
326+
// Test all dtype functions return correct PyTorch dtype codes
327+
TEST_F(CommonShimsTest, AllDtypesReturnCorrectValues) {
328+
EXPECT_EQ(aoti_torch_dtype_float32(), 6); // PyTorch's float32 dtype code
329+
EXPECT_EQ(aoti_torch_dtype_bfloat16(), 15); // PyTorch's bfloat16 dtype code
330+
EXPECT_EQ(aoti_torch_dtype_int8(), 1); // PyTorch's int8 dtype code
331+
EXPECT_EQ(aoti_torch_dtype_int16(), 2); // PyTorch's int16 dtype code
332+
EXPECT_EQ(aoti_torch_dtype_int32(), 3); // PyTorch's int32 dtype code
333+
EXPECT_EQ(aoti_torch_dtype_int64(), 4); // PyTorch's int64 dtype code
334+
EXPECT_EQ(aoti_torch_dtype_bool(), 11); // PyTorch's bool dtype code
335+
}

backends/aoti/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
4545
return executorch::aten::ScalarType::Long;
4646
case 6: // PyTorch's float32 dtype code
4747
return executorch::aten::ScalarType::Float;
48+
case 11: // PyTorch's bool dtype code
49+
return executorch::aten::ScalarType::Bool;
4850
case 15: // PyTorch's bfloat16 dtype code
4951
return executorch::aten::ScalarType::BFloat16;
5052
// Future support for additional dtypes can be added here

backends/cuda/runtime/utils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ enum class SupportedDTypes : int32_t {
6262
INT32 = 3, // PyTorch's int32 dtype code
6363
INT64 = 4, // PyTorch's int64 dtype code
6464
FLOAT32 = 6, // PyTorch's float32 dtype code
65+
BOOL = 11, // PyTorch's bool dtype code
6566
BFLOAT16 = 15, // PyTorch's bfloat16 dtype code
6667
};
6768

@@ -84,6 +85,7 @@ inline bool is_dtype_supported_in_et_cuda(int32_t dtype) {
8485
case static_cast<int32_t>(SupportedDTypes::INT32):
8586
case static_cast<int32_t>(SupportedDTypes::INT64):
8687
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
88+
case static_cast<int32_t>(SupportedDTypes::BOOL):
8789
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
8890
return true;
8991
default:
@@ -96,13 +98,14 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
9698
ET_CHECK_OR_RETURN_ERROR(
9799
is_dtype_supported_in_et_cuda(dtype),
98100
InvalidArgument,
99-
"Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bfloat16)",
101+
"Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bool), %d (bfloat16)",
100102
dtype,
101103
static_cast<int32_t>(SupportedDTypes::INT8),
102104
static_cast<int32_t>(SupportedDTypes::INT16),
103105
static_cast<int32_t>(SupportedDTypes::INT32),
104106
static_cast<int32_t>(SupportedDTypes::INT64),
105107
static_cast<int32_t>(SupportedDTypes::FLOAT32),
108+
static_cast<int32_t>(SupportedDTypes::BOOL),
106109
static_cast<int32_t>(SupportedDTypes::BFLOAT16));
107110

108111
return Error::Ok;

0 commit comments

Comments
 (0)