Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,8 +481,15 @@ def __init__(self, quant_type: str, modules_to_not_convert: Optional[List[str]]

TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability less than 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
)

raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the "
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)

Expand Down Expand Up @@ -652,13 +659,13 @@ def get_apply_tensor_subclass(self):

def __repr__(self):
r"""
Example of how this looks for `TorchAoConfig("uint_a16w4", group_size=32)`:
Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`:

```
TorchAoConfig {
"modules_to_not_convert": null,
"quant_method": "torchao",
"quant_type": "uint_a16w4",
"quant_type": "uint4wo",
"quant_type_kwargs": {
"group_size": 32
}
Expand Down
Loading