Commit ca60ad8

Improve TorchAO error message (#10627)
improve error message
1 parent: beacaa5

File tree

1 file changed (+10, -3 lines)

src/diffusers/quantizers/quantization_config.py

Lines changed: 10 additions & 3 deletions
@@ -481,8 +481,15 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]]
 
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
         if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
+            is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
+            if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
+                raise ValueError(
+                    f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
+                    f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
+                )
+
             raise ValueError(
-                f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the "
+                f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
                 f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
             )
 
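The new branch gives a clearer error when a floating-point quantization type (a `quant_type` starting with `float` or `fp`) is requested on a GPU whose CUDA compute capability is below 8.9. A minimal sketch of how a caller might guard against this, assuming torchao is installed, a CUDA GPU is present, and that the float8 quant type name used here (`float8wo_e4m3`) exists in your diffusers version:

```python
# Hedged sketch: pick a TorchAO quant type based on GPU capability, so the
# float8 path is only requested where this commit's check would allow it.
# Assumptions: torchao is installed, a CUDA GPU is present, and the quant
# type name "float8wo_e4m3" is available in this diffusers version.
import torch
from diffusers import TorchAoConfig

major, minor = torch.cuda.get_device_capability()  # e.g. (8, 9) on an RTX 4090
if (major, minor) >= (8, 9):
    quantization_config = TorchAoConfig("float8wo_e4m3")
else:
    # On older GPUs a float/fp quant type now fails with the clearer
    # ValueError added above, so fall back to an integer scheme instead.
    quantization_config = TorchAoConfig("uint4wo", group_size=32)
```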

@@ -652,13 +659,13 @@ def get_apply_tensor_subclass(self):
 
     def __repr__(self):
         r"""
-        Example of how this looks for `TorchAoConfig("uint_a16w4", group_size=32)`:
+        Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`:
 
         ```
         TorchAoConfig {
           "modules_to_not_convert": null,
           "quant_method": "torchao",
-          "quant_type": "uint_a16w4",
+          "quant_type": "uint4wo",
           "quant_type_kwargs": {
             "group_size": 32
           }
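For reference, the corrected docstring example can be reproduced by printing the config's repr. A small sketch, assuming torchao is installed and a CUDA device is available (exact JSON formatting may differ slightly between diffusers versions):

```python
# Hedged sketch: reproduce the repr documented above for
# TorchAoConfig("uint4wo", group_size=32).
# Assumptions: torchao is installed and a CUDA GPU is available.
from diffusers import TorchAoConfig

config = TorchAoConfig("uint4wo", group_size=32)
print(repr(config))
# Expected shape of the output (per the docstring above):
# TorchAoConfig {
#   "modules_to_not_convert": null,
#   "quant_method": "torchao",
#   "quant_type": "uint4wo",
#   "quant_type_kwargs": {
#     "group_size": 32
#   }
# }
```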
