Skip to content

Commit 997e56c

Browse files
committed
add sharded + device_map check
1 parent 0d96a89 commit 997e56c

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
803803
subfolder=subfolder or "",
804804
)
805805
if hf_quantizer is not None:
806+
is_torchao_quantization_method = quantization_config.quant_method == QuantizationMethod.TORCHAO
807+
if device_map is not None and is_torchao_quantization_method:
808+
raise NotImplementedError(
809+
"Loading sharded checkpoints, while passing `device_map`, is not supported with `torchao` quantization. This will be supported in the near future."
810+
)
811+
806812
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
807813
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
808814
is_sharded = False

0 commit comments

Comments
 (0)