Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backends/aoti/common_shims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ int32_t aoti_torch_dtype_int32() {
return 3; // PyTorch's int32 dtype code
}

int32_t aoti_torch_dtype_bool() {
return 11; // PyTorch's bool dtype code
}

int32_t aoti_torch_dtype_int64() {
return 4; // PyTorch's int64 dtype code
}
Expand Down
1 change: 1 addition & 0 deletions backends/aoti/common_shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ int32_t aoti_torch_dtype_int8();
int32_t aoti_torch_dtype_int16();
int32_t aoti_torch_dtype_int32();
int32_t aoti_torch_dtype_int64();
int32_t aoti_torch_dtype_bool();

// Dtype utility function needed by Metal backend
size_t aoti_torch_dtype_element_size(int32_t dtype);
Expand Down
11 changes: 11 additions & 0 deletions backends/aoti/tests/test_common_shims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,14 @@ TEST_F(CommonShimsTest, IndependentCaches) {
// Sizes and strides pointers should be different (different caches)
EXPECT_NE(sizes_ptr1, strides_ptr1);
}

// Test all dtype functions return correct PyTorch dtype codes
TEST_F(CommonShimsTest, AllDtypesReturnCorrectValues) {
EXPECT_EQ(aoti_torch_dtype_float32(), 6); // PyTorch's float32 dtype code
EXPECT_EQ(aoti_torch_dtype_bfloat16(), 15); // PyTorch's bfloat16 dtype code
EXPECT_EQ(aoti_torch_dtype_int8(), 1); // PyTorch's int8 dtype code
EXPECT_EQ(aoti_torch_dtype_int16(), 2); // PyTorch's int16 dtype code
EXPECT_EQ(aoti_torch_dtype_int32(), 3); // PyTorch's int32 dtype code
EXPECT_EQ(aoti_torch_dtype_int64(), 4); // PyTorch's int64 dtype code
EXPECT_EQ(aoti_torch_dtype_bool(), 11); // PyTorch's bool dtype code
}
2 changes: 2 additions & 0 deletions backends/aoti/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ inline executorch::aten::ScalarType dtype_to_scalar_type(int32_t dtype) {
return executorch::aten::ScalarType::Long;
case 6: // PyTorch's float32 dtype code
return executorch::aten::ScalarType::Float;
case 11: // PyTorch's bool dtype code
return executorch::aten::ScalarType::Bool;
case 15: // PyTorch's bfloat16 dtype code
return executorch::aten::ScalarType::BFloat16;
// Future support for additional dtypes can be added here
Expand Down
5 changes: 4 additions & 1 deletion backends/cuda/runtime/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ enum class SupportedDTypes : int32_t {
INT32 = 3, // PyTorch's int32 dtype code
INT64 = 4, // PyTorch's int64 dtype code
FLOAT32 = 6, // PyTorch's float32 dtype code
BOOL = 11, // PyTorch's bool dtype code
BFLOAT16 = 15, // PyTorch's bfloat16 dtype code
};

Expand All @@ -84,6 +85,7 @@ inline bool is_dtype_supported_in_et_cuda(int32_t dtype) {
case static_cast<int32_t>(SupportedDTypes::INT32):
case static_cast<int32_t>(SupportedDTypes::INT64):
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
case static_cast<int32_t>(SupportedDTypes::BOOL):
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
return true;
default:
Expand All @@ -96,13 +98,14 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
ET_CHECK_OR_RETURN_ERROR(
is_dtype_supported_in_et_cuda(dtype),
InvalidArgument,
"Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bfloat16)",
"Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bool), %d (bfloat16)",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any macro we can use here? seems it is getting longer and longer loll

dtype,
static_cast<int32_t>(SupportedDTypes::INT8),
static_cast<int32_t>(SupportedDTypes::INT16),
static_cast<int32_t>(SupportedDTypes::INT32),
static_cast<int32_t>(SupportedDTypes::INT64),
static_cast<int32_t>(SupportedDTypes::FLOAT32),
static_cast<int32_t>(SupportedDTypes::BOOL),
static_cast<int32_t>(SupportedDTypes::BFLOAT16));

return Error::Ok;
Expand Down
Loading