File tree Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Original file line number Diff line number Diff line change @@ -1438,8 +1438,17 @@ def _load_pretrained_model(
14381438 if len (resolved_model_file ) > 1 :
14391439 resolved_model_file = logging .tqdm (resolved_model_file , desc = "Loading checkpoint shards" )
14401440
1441+ map_location = "cpu"
1442+ if (
1443+ device_map is not None
1444+ and hf_quantizer is not None
1445+ and hf_quantizer .quantization_config .quant_method == QuantizationMethod .TORCHAO
1446+ and hf_quantizer .quantfization_config .quant_type in ["int4_weight_only" , "autoquant" ]
1447+ ):
1448+ map_location = torch .device ([d for d in device_map .values () if d not in ["cpu" , "disk" ]][0 ])
1449+
14411450 for shard_file in resolved_model_file :
1442- state_dict = load_state_dict (shard_file , dduf_entries = dduf_entries )
1451+ state_dict = load_state_dict (shard_file , dduf_entries = dduf_entries , map_location = map_location )
14431452
14441453 def _find_mismatched_keys (
14451454 state_dict ,
You can’t perform that action at this time.
0 commit comments