Commit d0b718a: apply review suggestions

Parent: 1873bb7

2 files changed: +3 −5 lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 4 deletions
@@ -718,11 +718,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         hf_quantizer = None
 
         if hf_quantizer is not None:
-            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
-            is_torchao_quantization_method = hf_quantizer.quantization_config.quant_method.value == "torchao"
-            if (is_bnb_quantization_method or is_torchao_quantization_method) and device_map is not None:
+            if device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes and torchao models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
                 )
 
             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
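
The review suggestion collapses the two per-backend checks into a single guard: as of this commit, any quantized load that also passes an explicit device_map raises, regardless of the quantization backend. Below is a minimal sketch of the resulting behavior, assuming diffusers' public TorchAoConfig API; the model id and subfolder are illustrative, not taken from the diff.

# Sketch of the guard's effect as of this commit (model id is illustrative).
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8_weight_only")

try:
    FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        subfolder="transformer",
        quantization_config=quantization_config,
        device_map="auto",  # any explicit device_map now trips the guard
        torch_dtype=torch.bfloat16,
    )
except NotImplementedError as err:
    # "Currently, `device_map` is automatically inferred for quantized models. ..."
    print(err)

# Omitting device_map lets the quantizer infer placement on its own.
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)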

tests/quantization/torchao/test_torchao.py

Lines changed: 1 addition & 1 deletion
@@ -418,7 +418,7 @@ def test_torch_compile(self):
         quantization_config = TorchAoConfig("int8_weight_only")
         components = self.get_dummy_components(quantization_config, model_id=model_id)
         pipe = FluxPipeline(**components)
-        pipe.to(device=torch_device, dtype=torch.bfloat16)
+        pipe.to(device=torch_device)
 
         inputs = self.get_dummy_inputs(torch_device)
         normal_output = pipe(**inputs)[0].flatten()[-32:]
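
The test change keeps only the device move. A plausible reading, not stated in the diff, is that pipe.to(dtype=torch.bfloat16) would attempt to cast the torchao int8-quantized transformer weights along with the rest of the pipeline, which the review suggestion avoids. A minimal sketch of the corrected pattern, reusing the fixture names from the diff:

# Sketch of the corrected call, using the test's own fixtures.
pipe = FluxPipeline(**components)

# Move to the target device only. Adding dtype=torch.bfloat16 here would also
# touch the int8_weight_only-quantized weights (interpretation, not from the diff).
pipe.to(device=torch_device)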
