@@ -62,6 +62,7 @@ enum class SupportedDTypes : int32_t {
   INT32 = 3, // PyTorch's int32 dtype code
   INT64 = 4, // PyTorch's int64 dtype code
   FLOAT32 = 6, // PyTorch's float32 dtype code
+  BOOL = 11, // PyTorch's bool dtype code
   BFLOAT16 = 15, // PyTorch's bfloat16 dtype code
 };

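These numeric values mirror PyTorch's c10::ScalarType codes, which is why BOOL slots in at 11 between FLOAT32 (6) and BFLOAT16 (15). A minimal standalone sketch of that correspondence, with the enum re-declared purely for illustration (INT8 = 1 and INT16 = 2 are inferred from the error message in the hunk below, not shown in this diff):

```cpp
#include <cassert>
#include <cstdint>

// Illustration only: the real enum lives in the CUDA AOTI utils header
// this diff modifies. Each value matches a c10::ScalarType code.
enum class SupportedDTypes : int32_t {
  INT8 = 1, // c10::ScalarType::Char
  INT16 = 2, // c10::ScalarType::Short
  INT32 = 3, // c10::ScalarType::Int
  INT64 = 4, // c10::ScalarType::Long
  FLOAT32 = 6, // c10::ScalarType::Float
  BOOL = 11, // c10::ScalarType::Bool
  BFLOAT16 = 15, // c10::ScalarType::BFloat16
};

int main() {
  // torch.bool tensors arrive tagged with scalar-type code 11, so the new
  // entry lets them pass the dtype check without any remapping.
  assert(static_cast<int32_t>(SupportedDTypes::BOOL) == 11);
  return 0;
}
```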
@@ -84,6 +85,7 @@ inline bool is_dtype_supported_in_et_cuda(int32_t dtype) {
     case static_cast<int32_t>(SupportedDTypes::INT32):
     case static_cast<int32_t>(SupportedDTypes::INT64):
     case static_cast<int32_t>(SupportedDTypes::FLOAT32):
+    case static_cast<int32_t>(SupportedDTypes::BOOL):
     case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
       return true;
     default:
@@ -96,13 +98,14 @@ inline AOTITorchError validate_dtype(int32_t dtype) {
   ET_CHECK_OR_RETURN_ERROR(
       is_dtype_supported_in_et_cuda(dtype),
       InvalidArgument,
-      "Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bfloat16)",
+      "Unsupported dtype: %d. Supported dtypes: %d (int8), %d (int16), %d (int32), %d (int64), %d (float32), %d (bool), %d (bfloat16)",
       dtype,
       static_cast<int32_t>(SupportedDTypes::INT8),
       static_cast<int32_t>(SupportedDTypes::INT16),
       static_cast<int32_t>(SupportedDTypes::INT32),
       static_cast<int32_t>(SupportedDTypes::INT64),
       static_cast<int32_t>(SupportedDTypes::FLOAT32),
+      static_cast<int32_t>(SupportedDTypes::BOOL),
       static_cast<int32_t>(SupportedDTypes::BFLOAT16));

   return Error::Ok;
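The observable effect: validate_dtype(11) now returns Error::Ok instead of failing with InvalidArgument. A self-contained sketch of the behavior change, assuming the scalar-type codes above (the stand-in function below mirrors the switch in is_dtype_supported_in_et_cuda but is not the real API):

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for is_dtype_supported_in_et_cuda, using the same codes as the
// diff above; written out here so the example compiles on its own.
bool is_supported(int32_t dtype) {
  switch (dtype) {
    case 1: // int8
    case 2: // int16
    case 3: // int32
    case 4: // int64
    case 6: // float32
    case 11: // bool (newly added by this change)
    case 15: // bfloat16
      return true;
    default:
      return false;
  }
}

int main() {
  assert(is_supported(11)); // torch.bool now passes validation
  assert(!is_supported(7)); // torch.float64 is still rejected
  return 0;
}
```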