@@ -1185,8 +1185,16 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
         state_dict = None
         if not is_sharded:
+            map_location = "cpu"
+            if (
+                device_map is not None
+                and hf_quantizer is not None
+                and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+                and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
+            ):
+                map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
             # Time to load the checkpoint
-            state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
+            state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries, map_location=map_location)
             # We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
             model._fix_state_dict_keys_on_load(state_dict)
 
@@ -1443,10 +1451,9 @@ def _load_pretrained_model(
             device_map is not None
             and hf_quantizer is not None
             and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
-            and hf_quantizer.quantfization_config.quant_type in ["int4_weight_only", "autoquant"]
+            and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
         ):
             map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
-
         for shard_file in resolved_model_file:
             state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, map_location=map_location)
 
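
Both hunks apply the same device-selection rule: weights default to loading on CPU, but for torchao "int4_weight_only" / "autoquant" checkpoints the state dict is mapped straight onto the first accelerator listed in device_map. A minimal standalone sketch of that rule follows; the helper name pick_map_location, the string-based quant arguments, and the example device_map values are illustrative assumptions, not part of the diff or the diffusers API.

# Illustrative sketch only; pick_map_location is a hypothetical helper.
import torch


def pick_map_location(device_map, quant_method=None, quant_type=None):
    # Default: materialize checkpoint tensors on CPU.
    map_location = "cpu"
    if (
        device_map is not None
        and quant_method == "torchao"  # simplified stand-in for QuantizationMethod.TORCHAO
        and quant_type in ["int4_weight_only", "autoquant"]
    ):
        # Pick the first device_map entry that is neither CPU- nor disk-offloaded.
        map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
    return map_location


# Example device maps (assumed for illustration):
print(pick_map_location({"": 0}))                                  # cpu (no torchao quantization)
print(pick_map_location({"": 0}, "torchao", "int4_weight_only"))   # cuda:0
print(pick_map_location({"text_encoder": "cpu", "unet": "cuda:0"}, "torchao", "autoquant"))  # cuda:0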