@@ -66,8 +66,8 @@
     _determine_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
-    _load_state_dict_into_model,
-    load_model_dict_into_meta,
+    load_shard_file,
+    load_shard_files_with_threadpool,
     load_state_dict,
 )

@@ -200,34 +200,6 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
     return last_tuple[1].dtype


-def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
-    """
-    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
-    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
-    parameters.
-
-    """
-    if model_to_load.device.type == "meta":
-        return False
-
-    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
-        return False
-
-    # Some models explicitly do not support param buffer assignment
-    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
-        logger.debug(
-            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
-        )
-        return False
-
-    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
-    first_key = next(iter(model_to_load.state_dict().keys()))
-    if start_prefix + first_key in state_dict:
-        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
-
-    return False
-
-
 @contextmanager
 def no_init_weights():
     """
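Note: the removed helper gated PyTorch's assignment-based `load_state_dict` fast path. For reference, a minimal standalone sketch (not from this codebase) of the `assign=True` behavior it was checking for, and why matching dtypes were required:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
ckpt = {k: v.clone() for k, v in model.state_dict().items()}

# assign=True rebinds the checkpoint tensors onto the module instead of
# copying into the existing parameter storage; dtypes must already match,
# which is exactly the condition the removed helper verified.
model.load_state_dict(ckpt, assign=True)
assert model.weight.data_ptr() == ckpt["weight"].data_ptr()
```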
@@ -926,6 +898,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

+        # Opt in via env var; only truthy strings ("1", "true", "yes", "on") enable it.
+        is_parallel_loading_enabled = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in {"1", "TRUE", "YES", "ON"}
+
+        if is_parallel_loading_enabled and not low_cpu_mem_usage:
+            raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
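With the flag parsed as above, the feature is opt-in through the environment. A usage sketch (the model class and repo name are illustrative only):

```python
import os

# Must be set before from_pretrained() runs; per the check above,
# parallel loading also requires low_cpu_mem_usage=True.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", low_cpu_mem_usage=True
)
```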
@@ -1261,6 +1239,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             hf_quantizer=hf_quantizer,
             keep_in_fp32_modules=keep_in_fp32_modules,
             dduf_entries=dduf_entries,
+            is_parallel_loading_enabled=is_parallel_loading_enabled,
         )
         loading_info = {
             "missing_keys": missing_keys,
@@ -1456,6 +1435,7 @@ def _load_pretrained_model(
         offload_state_dict: Optional[bool] = None,
         offload_folder: Optional[Union[str, os.PathLike]] = None,
         dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+        is_parallel_loading_enabled: bool = False,
     ):
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
@@ -1470,8 +1450,6 @@ def _load_pretrained_model(
                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

         mismatched_keys = []
-
-        assign_to_params_buffers = None
         error_msgs = []

         # Deal with offload
@@ -1499,63 +1477,45 @@ def _load_pretrained_model(
             # load_state_dict will manage the case where we pass a dict instead of a file
             # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
             resolved_model_file = [state_dict]
+        is_file = not isinstance(state_dict, dict)

-        if len(resolved_model_file) > 1:
-            resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
-
-        for shard_file in resolved_model_file:
-            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
-
-            def _find_mismatched_keys(
-                state_dict,
-                model_state_dict,
-                loaded_keys,
-                ignore_mismatched_sizes,
-            ):
-                mismatched_keys = []
-                if ignore_mismatched_sizes:
-                    for checkpoint_key in loaded_keys:
-                        model_key = checkpoint_key
-                        # If the checkpoint is sharded, we may not have the key here.
-                        if checkpoint_key not in state_dict:
-                            continue
-
-                        if (
-                            model_key in model_state_dict
-                            and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
-                        ):
-                            mismatched_keys.append(
-                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                            )
-                            del state_dict[checkpoint_key]
-                return mismatched_keys
-
-            mismatched_keys += _find_mismatched_keys(
-                state_dict,
+        # prepare the arguments.
+        args_list = [
+            (
+                model,
                 model_state_dict,
+                shard_file,
+                device_map,
+                dtype,
+                hf_quantizer,
+                keep_in_fp32_modules,
+                dduf_entries,
                 loaded_keys,
+                unexpected_keys,
+                offload_index,
+                offload_folder,
+                state_dict_index,
+                state_dict_folder,
                 ignore_mismatched_sizes,
+                low_cpu_mem_usage,
             )
+            for shard_file in resolved_model_file
+        ]

-        if low_cpu_mem_usage:
-            offload_index, state_dict_index = load_model_dict_into_meta(
-                model,
-                state_dict,
-                device_map=device_map,
-                dtype=dtype,
-                hf_quantizer=hf_quantizer,
-                keep_in_fp32_modules=keep_in_fp32_modules,
-                unexpected_keys=unexpected_keys,
-                offload_folder=offload_folder,
-                offload_index=offload_index,
-                state_dict_index=state_dict_index,
-                state_dict_folder=state_dict_folder,
-            )
-        else:
-            if assign_to_params_buffers is None:
-                assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+        if is_parallel_loading_enabled and is_file:
+            offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool(
+                args_list
+            )
+            error_msgs += _error_msgs
+            mismatched_keys += _mismatched_keys
+        else:
+            if len(args_list) > 1:
+                args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")

-            error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+            for args in args_list:
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_file(args)
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys

         if offload_index is not None and len(offload_index) > 0:
             save_offload_index(offload_index, offload_folder)
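The body of `load_shard_files_with_threadpool` is not shown in this diff. A minimal sketch of the dispatch it implies, assuming each worker simply runs `load_shard_file` (imported above) and that both helpers return `(offload_index, state_dict_index, mismatched_keys, error_msgs)` as in the sequential branch; the worker count here is an arbitrary placeholder:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def load_shard_files_with_threadpool(args_list, max_workers=8):
    # Sketch only: safetensors/torch file reads release the GIL, so
    # threads can overlap shard I/O even without multiprocessing.
    offload_index = state_dict_index = None
    mismatched_keys, error_msgs = [], []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(load_shard_file, args) for args in args_list]
        for future in as_completed(futures):
            offload_index, state_dict_index, _mismatched, _errors = future.result()
            mismatched_keys += _mismatched
            error_msgs += _errors
    return offload_index, state_dict_index, mismatched_keys, error_msgs
```

This mirrors why the diff gates the threaded path on `is_file`: an in-memory `state_dict` yields a single-element `args_list`, so there is nothing to parallelize.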