
Commit cad4954 ("quick fix")
Parent: a99663d

1 file changed (+10, -3)

src/diffusers/models/modeling_utils.py

Lines changed: 10 additions & 3 deletions
@@ -1185,8 +1185,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         state_dict = None
         if not is_sharded:
+            map_location = "cpu"
+            if (
+                device_map is not None
+                and hf_quantizer is not None
+                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+                and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
+            ):
+                map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
             # Time to load the checkpoint
-            state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
+            state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries, map_location=map_location)
             # We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
             model._fix_state_dict_keys_on_load(state_dict)
 

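Context for the first hunk: it mirrors the sharded-checkpoint branch in `_load_pretrained_model` (second hunk below). With torchao's `int4_weight_only` and `autoquant` schemes the packed weight layouts are device-dependent, so the non-sharded path now deserializes onto the first non-CPU, non-disk device named in `device_map` instead of defaulting to `"cpu"`. A minimal standalone sketch of that selection logic; the `pick_map_location` helper name is hypothetical, not part of diffusers:

import torch

def pick_map_location(device_map, is_torchao_int4_or_autoquant):
    # Hypothetical helper mirroring the added branch: default to CPU,
    # but for torchao int4_weight_only/autoquant checkpoints return the
    # first accelerator entry in device_map, skipping CPU and disk
    # offload targets.
    if device_map is None or not is_torchao_int4_or_autoquant:
        return "cpu"
    accelerators = [d for d in device_map.values() if d not in ["cpu", "disk"]]
    return torch.device(accelerators[0])

# Example: a device_map that offloads part of the model to CPU.
print(pick_map_location({"transformer": 0, "vae": "cpu"}, True))  # cuda:0

Like the code in the diff, this indexes the first accelerator entry, so it assumes `device_map` names at least one device besides "cpu" and "disk" when the quantization check passes.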
@@ -1443,10 +1451,9 @@ def _load_pretrained_model(
             device_map is not None
             and hf_quantizer is not None
             and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
-            and hf_quantizer.quantfization_config.quant_type in ["int4_weight_only", "autoquant"]
+            and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
         ):
             map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
-
         for shard_file in resolved_model_file:
             state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, map_location=map_location)
 
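The second hunk corrects the `quantfization_config` typo to `quantization_config` and drops a stray blank line; the attribute fix is presumably the "quick fix" the commit message refers to. As for where `map_location` ends up: for a plain PyTorch checkpoint shard it corresponds to `torch.load`'s `map_location` argument, which controls where each tensor is materialized during deserialization (safetensors files take a comparable `device` argument instead). A minimal sketch with an illustrative file name:

import torch

# Sketch only: map_location tells torch.load where to materialize tensors
# while deserializing; "cuda:0" here stands in for the device picked from
# device_map in the diff above.
map_location = torch.device("cuda:0")
state_dict = torch.load("diffusion_pytorch_model.bin", map_location=map_location)
print({k: v.device for k, v in list(state_dict.items())[:3]})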