[core] parallel loading of shards #12028
Changes from 15 commits
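This PR parallelizes the loading of sharded checkpoints behind an environment-variable opt-in. A minimal usage sketch, assuming any sharded diffusers checkpoint (the model class and checkpoint path below are placeholders, not from the PR):

```python
# Opt in to parallel shard loading before calling from_pretrained.
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"  # gate added in this PR
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"   # optional; defaults to 8

from diffusers import UNet2DConditionModel

# low_cpu_mem_usage must stay enabled (it defaults to True); the PR raises
# NotImplementedError otherwise.
model = UNet2DConditionModel.from_pretrained(
    "some/sharded-checkpoint",  # placeholder repo id
    low_cpu_mem_usage=True,
)
```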
The first file gains a `concurrent.futures` import:

```diff
@@ -20,6 +20,7 @@
 import os
 from array import array
 from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
```
The next hunk, `@@ -310,6 +311,130 @@ def load_model_dict_into_meta(`, adds the new helpers right after `load_model_dict_into_meta`'s trailing `return offload_index, state_dict_index`:
```python
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
    """
    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the
    model's parameters.
    """
    if model_to_load.device.type == "meta":
        return False

    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
        return False

    # Some models explicitly do not support param buffer assignment
    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
        logger.debug(
            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
        )
        return False

    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
    first_key = next(iter(model_to_load.state_dict().keys()))
    if start_prefix + first_key in state_dict:
        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype

    return False
```
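A minimal sketch of the dtype gate at the end of that check, with a plain `nn.Linear` standing in for the model (the real helper also consults `model.device` and the `_supports_param_buffer_assignment` flag, which a bare module lacks):

```python
# Minimal sketch of the dtype comparison: an fp16 checkpoint loaded into an
# fp32 model fails the check, so loading falls back to the slower copy path.
import torch.nn as nn

model = nn.Linear(4, 4)  # fp32 parameters
state_dict = {k: v.half() for k, v in model.state_dict().items()}  # fp16 checkpoint

first_key = next(iter(model.state_dict()))
can_assign = state_dict[first_key].dtype == model.state_dict()[first_key].dtype
print(can_assign)  # False
```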
`load_shard_file` bundles the per-shard work (read the shard, check for mismatched keys, then either load onto the meta device or load in place) behind a single-argument signature so it can be submitted to an executor:

```python
def load_shard_file(args):
    (
        model,
        model_state_dict,
        shard_file,
        device_map,
        dtype,
        hf_quantizer,
        keep_in_fp32_modules,
        dduf_entries,
        loaded_keys,
        unexpected_keys,
        offload_index,
        offload_folder,
        state_dict_index,
        state_dict_folder,
        ignore_mismatched_sizes,
        low_cpu_mem_usage,
    ) = args
    assign_to_params_buffers = None
    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
    mismatched_keys = _find_mismatched_keys(
        state_dict,
        model_state_dict,
        loaded_keys,
        ignore_mismatched_sizes,
    )
    error_msgs = []
    if low_cpu_mem_usage:
        offload_index, state_dict_index = load_model_dict_into_meta(
            model,
            state_dict,
            device_map=device_map,
            dtype=dtype,
            hf_quantizer=hf_quantizer,
            keep_in_fp32_modules=keep_in_fp32_modules,
            unexpected_keys=unexpected_keys,
            offload_folder=offload_folder,
            offload_index=offload_index,
            state_dict_index=state_dict_index,
            state_dict_folder=state_dict_folder,
        )
    else:
        if assign_to_params_buffers is None:
            assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)

        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
    return offload_index, state_dict_index, mismatched_keys, error_msgs
```
`load_shard_files_with_threadpool` fans those argument tuples out across a thread pool, sized by the `HF_PARALLEL_LOADING_WORKERS` environment variable (default 8):

```python
def load_shard_files_with_threadpool(args_list):
    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))

    # Do not spawn any more workers than needed.
    num_workers = min(len(args_list), num_workers)

    logger.info(f"Loading model weights in parallel with {num_workers} workers...")

    error_msgs = []
    mismatched_keys = []

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
            futures = [executor.submit(load_shard_file, arg) for arg in args_list]
            for future in as_completed(futures):
                result = future.result()
                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
                error_msgs += _error_msgs
                mismatched_keys += _mismatched_keys
                pbar.update(1)

    return offload_index, state_dict_index, mismatched_keys, error_msgs
```
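Because `as_completed` yields futures in completion order rather than submission order, the per-shard results must be aggregated as they arrive. A self-contained toy sketch of the same pattern (the names here are illustrative, not from the PR):

```python
# Toy sketch of the submit/as_completed pattern used above: results arrive in
# completion order, so the consumer must not assume shard ordering.
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def fake_load(shard_id: int) -> int:
    time.sleep(0.1 * (3 - shard_id))  # later shards finish first
    return shard_id

with ThreadPoolExecutor(max_workers=3) as executor:
    futures = [executor.submit(fake_load, i) for i in range(3)]
    for future in as_completed(futures):
        print(future.result())  # likely prints 2, 1, 0
```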
`_find_mismatched_keys` compares checkpoint tensor shapes against the model's and drops mismatches when `ignore_mismatched_sizes` is set:

```python
def _find_mismatched_keys(
    state_dict,
    model_state_dict,
    loaded_keys,
    ignore_mismatched_sizes,
):
    mismatched_keys = []
    if ignore_mismatched_sizes:
        for checkpoint_key in loaded_keys:
            model_key = checkpoint_key
            # If the checkpoint is sharded, we may not have the key here.
            if checkpoint_key not in state_dict:
                continue

            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
                mismatched_keys.append(
                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
                )
                del state_dict[checkpoint_key]
    return mismatched_keys
```
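A minimal sketch of that behavior, reusing the helper above with toy tensors (the key name is illustrative):

```python
# A resized checkpoint tensor is reported as mismatched and removed from the
# incoming state dict so the rest of the shard can still be loaded.
import torch

model_state_dict = {"proj.weight": torch.zeros(4, 4)}
state_dict = {"proj.weight": torch.zeros(8, 4)}  # shape changed in the checkpoint

mismatched = _find_mismatched_keys(state_dict, model_state_dict, ["proj.weight"], ignore_mismatched_sizes=True)
print(mismatched)                    # [('proj.weight', torch.Size([8, 4]), torch.Size([4, 4]))]
print("proj.weight" in state_dict)   # False: dropped so it won't be copied
```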
The hunk's trailing context is the existing `_load_state_dict_into_model` signature:

```python
def _load_state_dict_into_model(
    model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
```
In the second file, `modeling_utils.py`, the imports pick up `ENV_VARS_TRUE_VALUES`:

```diff
@@ -41,6 +41,7 @@
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
+    ENV_VARS_TRUE_VALUES,
     FLAX_WEIGHTS_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
```
The new shard-loading helpers are imported in place of the lower-level functions, which are now called inside `load_shard_file`:

```diff
@@ -69,9 +70,8 @@
     _expand_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
-    _find_mismatched_keys,
-    _load_state_dict_into_model,
-    load_model_dict_into_meta,
+    load_shard_file,
+    load_shard_files_with_threadpool,
     load_state_dict,
 )
```
The old copy of `check_support_param_buffer_assignment`, identical to the version added above, is deleted from this file:

```diff
@@ -208,34 +208,6 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
     return last_tuple[1].dtype


-def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
-    # ... body unchanged, shown in full above ...


 @contextmanager
 def no_init_weights():
     """
```
`from_pretrained` reads the `HF_ENABLE_PARALLEL_LOADING` environment variable and rejects the combination with `low_cpu_mem_usage=False`:

```diff
@@ -988,6 +960,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

+        is_parallel_loading_enabled = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
+        if is_parallel_loading_enabled and not low_cpu_mem_usage:
+            raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
```
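A quick, hypothetical repro of that guard; since the check runs before any files are resolved, the placeholder path is never fetched:

```python
# Enabling parallel loading while disabling low_cpu_mem_usage raises immediately.
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"

from diffusers import UNet2DConditionModel

try:
    UNet2DConditionModel.from_pretrained("some/sharded-checkpoint", low_cpu_mem_usage=False)
except NotImplementedError as err:
    print(err)  # Parallel loading is not supported when not using `low_cpu_mem_usage`.
```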
The flag is threaded through to `_load_pretrained_model`:

```diff
@@ -1323,6 +1299,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 dduf_entries=dduf_entries,
+                is_parallel_loading_enabled=is_parallel_loading_enabled,
             )
             loading_info = {
                 "missing_keys": missing_keys,
```
`_load_pretrained_model` grows the new keyword:

```diff
@@ -1518,6 +1495,7 @@ def _load_pretrained_model(
         offload_state_dict: Optional[bool] = None,
         offload_folder: Optional[Union[str, os.PathLike]] = None,
         dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+        is_parallel_loading_enabled: Optional[bool] = False,
     ):
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
```
`mismatched_keys` and `error_msgs` are now initialized earlier, since both loading paths append to them:

```diff
@@ -1531,6 +1509,9 @@ def _load_pretrained_model(
             for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

+        mismatched_keys = []
+        error_msgs = []
+
         # Deal with offload
         if device_map is not None and "disk" in device_map.values():
             if offload_folder is None:
```
Finally, the per-shard loop is replaced: the arguments are packed into one tuple per shard, then dispatched either to the thread pool or to a sequential loop over the same `load_shard_file` helper:

```diff
@@ -1566,37 +1547,43 @@ def _load_pretrained_model(
             # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
             resolved_model_file = [state_dict]

-        if len(resolved_model_file) > 1:
-            resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
-
-        mismatched_keys = []
-        assign_to_params_buffers = None
-        error_msgs = []
+        # prepare the arguments.
+        args_list = [
+            (
+                model,
+                model_state_dict,
+                shard_file,
+                device_map,
+                dtype,
+                hf_quantizer,
+                keep_in_fp32_modules,
+                dduf_entries,
+                loaded_keys,
+                unexpected_keys,
+                offload_index,
+                offload_folder,
+                state_dict_index,
+                state_dict_folder,
+                ignore_mismatched_sizes,
+                low_cpu_mem_usage,
+            )
+            for shard_file in resolved_model_file
+        ]

-        for shard_file in resolved_model_file:
-            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
-            mismatched_keys += _find_mismatched_keys(
-                state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
-            )
-
-            if low_cpu_mem_usage:
-                offload_index, state_dict_index = load_model_dict_into_meta(
-                    model,
-                    state_dict,
-                    device_map=device_map,
-                    dtype=dtype,
-                    hf_quantizer=hf_quantizer,
-                    keep_in_fp32_modules=keep_in_fp32_modules,
-                    unexpected_keys=unexpected_keys,
-                    offload_folder=offload_folder,
-                    offload_index=offload_index,
-                    state_dict_index=state_dict_index,
-                    state_dict_folder=state_dict_folder,
-                )
-            else:
-                if assign_to_params_buffers is None:
-                    assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
-                error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+        if is_parallel_loading_enabled:
+            offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_files_with_threadpool(
+                args_list
+            )
+            error_msgs += _error_msgs
+            mismatched_keys += _mismatched_keys
+        else:
+            if len(args_list) > 1:
+                args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")
+
+            for args in args_list:
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_shard_file(args)
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys

         empty_device_cache()
```
|  | ||
A review comment on the new helpers notes:

> Moved it here from `modeling_utils.py`.