Skip to content

Commit edd98db

Browse files
committed
bnb device map check
1 parent 101d10c commit edd98db

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
671671
hf_quantizer = None
672672

673673
if hf_quantizer is not None:
674+
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
675+
if is_bnb_quantization_method and device_map is not None:
676+
raise NotImplementedError(
677+
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
678+
)
679+
674680
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
675681
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
676682

0 commit comments

Comments
 (0)