Skip to content

Commit c6651f9

Browse files
committed
udpate
1 parent 1a29a99 commit c6651f9

File tree

2 files changed

+146
-132
lines changed

2 files changed

+146
-132
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -719,9 +719,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
719719

720720
if hf_quantizer is not None:
721721
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
722-
if is_bnb_quantization_method and device_map is not None:
722+
is_torchao_quantization_method = hf_quantizer.quantization_config.quant_method.value == "torchao"
723+
if (is_bnb_quantization_method or is_torchao_quantization_method) and device_map is not None:
723724
raise NotImplementedError(
724-
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
725+
"Currently, `device_map` is automatically inferred for quantized bitsandbytes and torchao models. Support for providing `device_map` as an input will be added in the future."
725726
)
726727

727728
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)

0 commit comments

Comments
 (0)