Skip to content

Commit a99663d

Browse files
committed
load tensors on cuda
1 parent 05c8b42 commit a99663d

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1438,8 +1438,17 @@ def _load_pretrained_model(
14381438
if len(resolved_model_file) > 1:
14391439
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
14401440

1441+
map_location = "cpu"
1442+
if (
1443+
device_map is not None
1444+
and hf_quantizer is not None
1445+
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
1446+
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
1447+
):
1448+
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
1449+
14411450
for shard_file in resolved_model_file:
1442-
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
1451+
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, map_location=map_location)
14431452

14441453
def _find_mismatched_keys(
14451454
state_dict,

0 commit comments

Comments
 (0)