6262 load_or_create_model_card ,
6363 populate_model_card ,
6464)
65+ from ..utils .torch_utils import device_synchronize , empty_device_cache
6566from .model_loading_utils import (
67+ _caching_allocator_warmup ,
6668 _determine_device_map ,
69+ _expand_device_map ,
6770 _fetch_index_file ,
6871 _fetch_index_file_legacy ,
72+ _find_mismatched_keys ,
6973 _load_state_dict_into_model ,
7074 load_model_dict_into_meta ,
7175 load_state_dict ,
@@ -1469,11 +1473,6 @@ def _load_pretrained_model(
14691473 for pat in cls ._keys_to_ignore_on_load_unexpected :
14701474 unexpected_keys = [k for k in unexpected_keys if re .search (pat , k ) is None ]
14711475
1472- mismatched_keys = []
1473-
1474- assign_to_params_buffers = None
1475- error_msgs = []
1476-
14771476 # Deal with offload
14781477 if device_map is not None and "disk" in device_map .values ():
14791478 if offload_folder is None :
@@ -1482,18 +1481,21 @@ def _load_pretrained_model(
14821481 " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
14831482 " offers the weights in this format."
14841483 )
1485- if offload_folder is not None :
1484+ else :
14861485 os .makedirs (offload_folder , exist_ok = True )
14871486 if offload_state_dict is None :
14881487 offload_state_dict = True
14891488
1489+ # Caching allocator warmup
1490+ if device_map is not None :
1491+ expanded_device_map = _expand_device_map (device_map , expected_keys )
1492+ _caching_allocator_warmup (model , expanded_device_map , dtype )
1493+
14901494 offload_index = {} if device_map is not None and "disk" in device_map .values () else None
1495+ state_dict_folder , state_dict_index = None , None
14911496 if offload_state_dict :
14921497 state_dict_folder = tempfile .mkdtemp ()
14931498 state_dict_index = {}
1494- else :
1495- state_dict_folder = None
1496- state_dict_index = None
14971499
14981500 if state_dict is not None :
14991501 # load_state_dict will manage the case where we pass a dict instead of a file
@@ -1503,38 +1505,14 @@ def _load_pretrained_model(
15031505 if len (resolved_model_file ) > 1 :
15041506 resolved_model_file = logging .tqdm (resolved_model_file , desc = "Loading checkpoint shards" )
15051507
1508+ mismatched_keys = []
1509+ assign_to_params_buffers = None
1510+ error_msgs = []
1511+
15061512 for shard_file in resolved_model_file :
15071513 state_dict = load_state_dict (shard_file , dduf_entries = dduf_entries )
1508-
1509- def _find_mismatched_keys (
1510- state_dict ,
1511- model_state_dict ,
1512- loaded_keys ,
1513- ignore_mismatched_sizes ,
1514- ):
1515- mismatched_keys = []
1516- if ignore_mismatched_sizes :
1517- for checkpoint_key in loaded_keys :
1518- model_key = checkpoint_key
1519- # If the checkpoint is sharded, we may not have the key here.
1520- if checkpoint_key not in state_dict :
1521- continue
1522-
1523- if (
1524- model_key in model_state_dict
1525- and state_dict [checkpoint_key ].shape != model_state_dict [model_key ].shape
1526- ):
1527- mismatched_keys .append (
1528- (checkpoint_key , state_dict [checkpoint_key ].shape , model_state_dict [model_key ].shape )
1529- )
1530- del state_dict [checkpoint_key ]
1531- return mismatched_keys
1532-
15331514 mismatched_keys += _find_mismatched_keys (
1534- state_dict ,
1535- model_state_dict ,
1536- loaded_keys ,
1537- ignore_mismatched_sizes ,
1515+ state_dict , model_state_dict , loaded_keys , ignore_mismatched_sizes
15381516 )
15391517
15401518 if low_cpu_mem_usage :
@@ -1554,11 +1532,11 @@ def _find_mismatched_keys(
15541532 else :
15551533 if assign_to_params_buffers is None :
15561534 assign_to_params_buffers = check_support_param_buffer_assignment (model , state_dict )
1557-
15581535 error_msgs += _load_state_dict_into_model (model , state_dict , assign_to_params_buffers )
15591536
1560- torch .cuda .synchronize ()
1561-
1537+ empty_device_cache ()
1538+ device_synchronize ()
1539+
15621540 if offload_index is not None and len (offload_index ) > 0 :
15631541 save_offload_index (offload_index , offload_folder )
15641542 offload_index = None
0 commit comments