@@ -493,7 +493,7 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]]
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
         if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
             is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
-            if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
+            if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
                 raise ValueError(
                     f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
                     f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
@@ -645,7 +645,7 @@ def generate_fpx_quantization_types(bits: int):
             QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
             QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
 
-            if cls._is_cuda_capability_atleast_8_9():
+            if cls._is_xpu_or_cuda_capability_atleast_8_9():
                 QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
 
             return QUANTIZATION_TYPES
@@ -655,14 +655,16 @@ def generate_fpx_quantization_types(bits: int):
             )
 
     @staticmethod
-    def _is_cuda_capability_atleast_8_9() -> bool:
+    def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
         if torch.cuda.is_available():
             major, minor = torch.cuda.get_device_capability()
             if major == 8:
                 return minor >= 9
             return major >= 9
-        else:
+        elif torch.xpu.is_available():
             return True
+        else:
+            raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
 
     def get_apply_tensor_subclass(self):
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
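
For context, the renamed helper now treats an available Intel XPU as capable of the float8/fpx quantization types and only errors out when neither a CUDA device nor an XPU is present. Below is a minimal standalone sketch of the same branching logic; the free-function name `has_fp8_support` is ours for illustration and is not part of the diff, and the `hasattr` guard is an extra assumption for older PyTorch builds without a `torch.xpu` module.

```python
import torch


def has_fp8_support() -> bool:
    # Hypothetical standalone version of _is_xpu_or_cuda_capability_atleast_8_9.
    if torch.cuda.is_available():
        # float8/fpx kernels need CUDA compute capability 8.9 (Ada) or newer.
        major, minor = torch.cuda.get_device_capability()
        if major == 8:
            return minor >= 9
        return major >= 9
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # Intel XPU builds are assumed capable; no capability probe is performed.
        return True
    raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
```

With this change, constructing a config with a float8 quant_type (e.g. `TorchAoConfig("float8wo")`) on an XPU machine no longer trips the capability ValueError, while a CPU-only environment now fails fast with a RuntimeError instead of silently returning True as the old `else` branch did.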