Commit d7fda4b

resolve comments
Signed-off-by: YAO Matrix <[email protected]>
1 parent 08e8038 commit d7fda4b

File tree

1 file changed (+6, -4 lines)


src/diffusers/quantizers/quantization_config.py

Lines changed: 6 additions & 4 deletions
@@ -493,7 +493,7 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]]
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
         if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
             is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
-            if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
+            if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
                 raise ValueError(
                     f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
                     f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
@@ -645,7 +645,7 @@ def generate_fpx_quantization_types(bits: int):
             QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
             QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
 
-            if cls._is_cuda_capability_atleast_8_9():
+            if cls._is_xpu_or_cuda_capability_atleast_8_9():
                 QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
 
             return QUANTIZATION_TYPES
@@ -655,14 +655,16 @@ def generate_fpx_quantization_types(bits: int):
         )
 
     @staticmethod
-    def _is_cuda_capability_atleast_8_9() -> bool:
+    def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
         if torch.cuda.is_available():
             major, minor = torch.cuda.get_device_capability()
             if major == 8:
                 return minor >= 9
             return major >= 9
-        else:
+        elif torch.xpu.is_available():
             return True
+        else:
+            raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
 
     def get_apply_tensor_subclass(self):
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
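
For reference, a minimal sketch of how the updated gate surfaces to user code (not part of this commit; it assumes a diffusers install with torchao available, and the quant type name is only illustrative): float/fp quant types are now accepted on Intel XPU as well as on CUDA GPUs with compute capability >= 8.9, and the config constructor raises on other devices.

```python
# Minimal sketch, not part of this commit: exercising the new device gate.
# Assumes diffusers with torchao installed; "float8wo_e4m3" is an illustrative quant type name.
import torch

from diffusers import TorchAoConfig

# The gate now accepts either backend.
has_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
print("cuda available:", torch.cuda.is_available(), "| xpu available:", has_xpu)

# Float/fp quant types are allowed on Intel XPU, or on CUDA GPUs with compute
# capability >= 8.9. On older CUDA GPUs this raises ValueError; with neither
# CUDA nor XPU present, _is_xpu_or_cuda_capability_atleast_8_9 raises RuntimeError.
quant_config = TorchAoConfig("float8wo_e4m3")
```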
