@@ -62,6 +62,7 @@ enum class SupportedDTypes : int32_t {
6262 INT32 = 3 , // PyTorch's int32 dtype code
6363 INT64 = 4 , // PyTorch's int64 dtype code
6464 FLOAT32 = 6 , // PyTorch's float32 dtype code
65+ BOOL = 11 , // PyTorch's bool dtype code
6566 BFLOAT16 = 15 , // PyTorch's bfloat16 dtype code
6667};
6768
@@ -84,6 +85,7 @@ inline bool is_dtype_supported_in_et_cuda(int32_t dtype) {
8485 case static_cast <int32_t >(SupportedDTypes::INT32):
8586 case static_cast <int32_t >(SupportedDTypes::INT64):
8687 case static_cast <int32_t >(SupportedDTypes::FLOAT32):
88+ case static_cast <int32_t >(SupportedDTypes::BOOL):
8789 case static_cast <int32_t >(SupportedDTypes::BFLOAT16):
8890 return true ;
8991 default :
@@ -96,13 +98,14 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
9698 ET_CHECK_OR_RETURN_ERROR (
9799 is_dtype_supported_in_et_cuda (dtype),
98100 InvalidArgument,
99- " Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bfloat16)" ,
101+ " Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bool), %d ( bfloat16)" ,
100102 dtype,
101103 static_cast <int32_t >(SupportedDTypes::INT8),
102104 static_cast <int32_t >(SupportedDTypes::INT16),
103105 static_cast <int32_t >(SupportedDTypes::INT32),
104106 static_cast <int32_t >(SupportedDTypes::INT64),
105107 static_cast <int32_t >(SupportedDTypes::FLOAT32),
108+ static_cast <int32_t >(SupportedDTypes::BOOL),
106109 static_cast <int32_t >(SupportedDTypes::BFLOAT16));
107110
108111 return Error::Ok;
0 commit comments