Enhance Model Loading By Providing Parallelism, Uses Optional Env Flag #36835
Changes from 10 commits
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| <!--Copyright 2020 The HuggingFace Team. All rights reserved. | ||
|  | ||
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
| the License. You may obtain a copy of the License at | ||
|  | ||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|  | ||
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
| specific language governing permissions and limitations under the License. | ||
|  | ||
| ⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be | ||
| rendered properly in your Markdown viewer. | ||
|  | ||
| --> | ||
|  | ||
| # Environment Variables | ||
|  | ||
| ## HF_ENABLE_PARALLEL_LOADING | ||
|  | ||
| Disabled by default. When enabled, torch and safetensors weights are loaded in parallel, which can significantly decrease the time it takes to load large models, often producing speedups greater than 50%. | ||
|  | ||
| Can be set to the string `"false"` or `"true"`, e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`. | ||
|  | ||
| For example, `facebook/opt-30b` on an AWS EC2 g4dn.metal instance loads in ~20s with this enabled vs ~45s without it. | ||
|  | ||
| Profile before committing to this environment variable; it will not produce speedups for smaller models. | ||
|  | ||
| NOTE: if you are loading a model onto anything other than the CPU, you must set `multiprocessing` to use the `spawn` start method like so: | ||
|  | ||
| ```py | ||
| import os | ||
|  | ||
| os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" | ||
|  | ||
| import multiprocessing | ||
| from transformers import pipeline | ||
|  | ||
| if __name__ == "__main__": | ||
|     # NOTE if a model loads on CPU this is not required | ||
|     multiprocessing.set_start_method("spawn", force=True) | ||
|  | ||
|     model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") | ||
| ``` | ||
|  | ||
| If loading onto a CUDA device, the code will crash if `multiprocessing.set_start_method("spawn", force=True)` has not been called. | ||
|  | ||
| ## HF_PARALLEL_LOADING_WORKERS | ||
|  | ||
| Determines how many child processes should be used when parallel loading is enabled. Default is `8`. Tune as you see fit. | ||
|  | ||
| ```py | ||
| import os | ||
|  | ||
| os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" | ||
| os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4" | ||
|  | ||
| import multiprocessing | ||
| from transformers import pipeline | ||
|  | ||
| if __name__ == "__main__": | ||
|     # NOTE if a model loads on CPU this is not required | ||
|     multiprocessing.set_start_method("spawn", force=True) | ||
|  | ||
|     model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") | ||
| ``` | ||
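
The doc above advises profiling before relying on the flag. A minimal timing sketch along those lines is shown below; it simply mirrors the examples above with a wall-clock timer added. The model name and hardware are illustrative, and you should compare warm runs (weights already downloaded) so you measure load time rather than download time.

```py
import os
import time

# Flip this to "false" for a baseline run and compare the two timings.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"

import multiprocessing
from transformers import pipeline

if __name__ == "__main__":
    # Required when loading onto CUDA devices, as noted above.
    multiprocessing.set_start_method("spawn", force=True)

    start = time.perf_counter()
    model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
    print(f"Loaded in {time.perf_counter() - start:.1f}s")
```
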
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -874,6 +874,128 @@ def _load_state_dict_into_meta_model( | |
| return disk_offload_index, cpu_offload_index | ||
|  | ||
|  | ||
| def resolve_state_dict_modules(model_to_load, state_dict, expected_keys): | ||
|     state_dict_modules = {} | ||
|  | ||
|     for tensor_name in state_dict.keys(): | ||
|         if tensor_name not in expected_keys: | ||
|             continue | ||
|  | ||
|         splits = tensor_name.split(".") | ||
|         module = model_to_load | ||
|         for split in splits: | ||
|             try: | ||
|                 module = getattr(module, split) | ||
|             except Exception as exception: | ||
|                 logger.warning(exception) | ||
|  | ||
|         state_dict_modules[tensor_name] = module | ||
|  | ||
|     return state_dict_modules | ||
|  | ||
|  | ||
| # This function is in global scope so it's picklable for multiprocessing | ||
| def load_shard_file(args): | ||
| ( | ||
| state_dict, | ||
| shard_file, | ||
| disk_only_shard_files, | ||
| low_cpu_mem_usage, | ||
| is_quantized, | ||
| device_map, | ||
| hf_quantizer, | ||
| key_renaming_mapping, | ||
| weights_only, | ||
| model_to_load, | ||
| ignore_mismatched_sizes, | ||
| prefix, | ||
| loading_base_model_from_task_state_dict, | ||
| expected_keys, | ||
| reverse_key_renaming_mapping, | ||
| disk_offload_folder, | ||
| disk_offload_index, | ||
| cpu_offload_folder, | ||
| cpu_offload_index, | ||
| is_offloaded_safetensors, | ||
| keep_in_fp32_modules, | ||
| unexpected_keys, | ||
| device_mesh, | ||
| ) = args | ||
| # Skip the load for shards that only contain disk-offloaded weights | ||
| if shard_file in disk_only_shard_files: | ||
| return [], [], disk_offload_index, cpu_offload_index, {} | ||
|  | ||
| map_location = "cpu" | ||
| if low_cpu_mem_usage: | ||
| if shard_file.endswith(".safetensors") and not is_quantized: | ||
| map_location = "meta" | ||
| elif ( | ||
| device_map is not None | ||
| and hf_quantizer is not None | ||
| and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO | ||
| and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] | ||
| ): | ||
| map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) | ||
|  | ||
| # If shard_file is "", we use the existing state_dict instead of loading it | ||
| if shard_file != "": | ||
| state_dict = load_state_dict( | ||
| shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only | ||
| ) | ||
|  | ||
| error_msgs = [] | ||
| mismatched_keys = [] | ||
|  | ||
| # Fix the key names | ||
| state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} | ||
|  | ||
| # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not | ||
| # matching the weights in the model. | ||
| mismatched_keys += _find_mismatched_keys( | ||
| model_to_load, | ||
| state_dict, | ||
| ignore_mismatched_sizes, | ||
| prefix if loading_base_model_from_task_state_dict else "", | ||
| ) | ||
|  | ||
| if low_cpu_mem_usage: | ||
| # Skip it with fsdp on ranks other than 0 | ||
| if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): | ||
| disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( | ||
| model_to_load, | ||
| state_dict, | ||
| shard_file, | ||
| expected_keys, | ||
| reverse_key_renaming_mapping, | ||
| device_map=device_map, | ||
| disk_offload_folder=disk_offload_folder, | ||
| disk_offload_index=disk_offload_index, | ||
| cpu_offload_folder=cpu_offload_folder, | ||
| cpu_offload_index=cpu_offload_index, | ||
| hf_quantizer=hf_quantizer, | ||
| is_safetensors=is_offloaded_safetensors, | ||
| keep_in_fp32_modules=keep_in_fp32_modules, | ||
| unexpected_keys=unexpected_keys, | ||
| device_mesh=device_mesh, | ||
| ) | ||
| else: | ||
| assign_params = check_support_param_buffer_assignment(model_to_load, state_dict) | ||
| if is_deepspeed_zero3_enabled(): | ||
| error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params) | ||
| else: | ||
| model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) | ||
|  | ||
| # We now figure out what in the state dict changed and store the module used for each layer; this will contain the device | ||
| # information we need in order to resolve all of the layers after multiprocessing, which we write back to the original model_to_load meta model | ||
| state_dict_modules = resolve_state_dict_modules(model_to_load, state_dict, expected_keys) | ||
|          | ||
|  | ||
| # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop | ||
| del state_dict | ||
|  | ||
| return mismatched_keys, error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules | ||
|  | ||
|  | ||
| def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: | ||
| if variant is not None: | ||
| path, name = weights_name.rsplit(".", 1) | ||
|  | @@ -4810,9 +4932,6 @@ def _load_pretrained_model( | |
| cpu_offload_folder = tempfile.mkdtemp() | ||
| cpu_offload_index = {} | ||
|  | ||
| # For nice tqdm bars | ||
| if checkpoint_files is not None and len(checkpoint_files) > 1: | ||
| checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards") | ||
| # To be able to iterate, even if we don't use it if the state_dict is already provided | ||
| elif state_dict is not None: | ||
| checkpoint_files = [""] | ||
|  | @@ -4827,73 +4946,99 @@ def _load_pretrained_model( | |
| expanded_device_map = expand_device_map(device_map, expected_keys) | ||
| caching_allocator_warmup(model_to_load, expanded_device_map) | ||
|  | ||
| error_msgs = [] | ||
| from multiprocessing import Pool | ||
|  | ||
| # Prepare arguments for multiprocessing | ||
| args_list = [ | ||
| ( | ||
| state_dict, | ||
| shard_file, | ||
| disk_only_shard_files, | ||
| low_cpu_mem_usage, | ||
| is_quantized, | ||
| device_map, | ||
| hf_quantizer, | ||
| key_renaming_mapping, | ||
| weights_only, | ||
| model_to_load, | ||
| ignore_mismatched_sizes, | ||
| prefix, | ||
| loading_base_model_from_task_state_dict, | ||
| expected_keys, | ||
| reverse_key_renaming_mapping, | ||
| disk_offload_folder, | ||
| disk_offload_index, | ||
| cpu_offload_folder, | ||
| cpu_offload_index, | ||
| is_offloaded_safetensors, | ||
| keep_in_fp32_modules, | ||
| unexpected_keys, | ||
| device_mesh, | ||
| ) | ||
| for shard_file in checkpoint_files | ||
| ] | ||
|  | ||
| mismatched_keys = [] | ||
| # Iterate on all the shards to load the weights | ||
| for shard_file in checkpoint_files: | ||
| # Skip the load for shards that only contain disk-offloaded weights | ||
| if shard_file in disk_only_shard_files: | ||
| continue | ||
| error_msgs = [] | ||
|  | ||
| map_location = "cpu" | ||
| if low_cpu_mem_usage: | ||
| if shard_file.endswith(".safetensors") and not is_quantized: | ||
| map_location = "meta" | ||
| elif ( | ||
| device_map is not None | ||
| and hf_quantizer is not None | ||
| and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO | ||
| and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] | ||
| ): | ||
| map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) | ||
| # Use multiprocessing Pool for parallel execution, off by default | ||
| if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")): | ||
|          | ||
| num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) | ||
| logger.info(f"Loading model weights in parallel with {num_workers} workers...") | ||
|          | ||
| state_dict_modules_list = [] | ||
|  | ||
| with Pool(processes=num_workers) as pool: | ||
| # For nice tqdm bars | ||
| with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: | ||
| # NOTE order does not matter, layers that changed per shard are unique and can be reassigned to the original meta model | ||
| for result in pool.imap_unordered(load_shard_file, args_list): | ||
| _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = ( | ||
| result | ||
| ) | ||
|  | ||
| # If shard_file is "", we use the existing state_dict instead of loading it | ||
| if shard_file != "": | ||
| state_dict = load_state_dict( | ||
| shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only | ||
| ) | ||
| mismatched_keys += _mismatched_keys | ||
| error_msgs += _error_msgs | ||
|  | ||
| # Fix the key names | ||
| state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} | ||
| state_dict_modules_list.append(state_dict_modules) | ||
|  | ||
| # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not | ||
| # matching the weights in the model. | ||
| mismatched_keys += _find_mismatched_keys( | ||
| model_to_load, | ||
| state_dict, | ||
| ignore_mismatched_sizes, | ||
| prefix if loading_base_model_from_task_state_dict else "", | ||
| ) | ||
| pbar.update(1) | ||
|  | ||
| if low_cpu_mem_usage: | ||
| # Skip it with fsdp on ranks other than 0 | ||
| if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): | ||
| disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( | ||
| model_to_load, | ||
| state_dict, | ||
| shard_file, | ||
| expected_keys, | ||
| reverse_key_renaming_mapping, | ||
| device_map=device_map, | ||
| disk_offload_folder=disk_offload_folder, | ||
| disk_offload_index=disk_offload_index, | ||
| cpu_offload_folder=cpu_offload_folder, | ||
| cpu_offload_index=cpu_offload_index, | ||
| hf_quantizer=hf_quantizer, | ||
| is_safetensors=is_offloaded_safetensors, | ||
| keep_in_fp32_modules=keep_in_fp32_modules, | ||
| unexpected_keys=unexpected_keys, | ||
| device_mesh=device_mesh, | ||
| ) | ||
| else: | ||
| assign_params = check_support_param_buffer_assignment(model_to_load, state_dict) | ||
| if is_deepspeed_zero3_enabled(): | ||
| error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params) | ||
| else: | ||
| model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) | ||
| # We now update each layer of the meta model with the tensor module refs that were set to specific devices in the copy of the meta model for each worker. | ||
| # We are transferring that state into the original ref (model_to_load) here. | ||
| # This is required because model_to_load is pickled when using multiprocessing, so each worker gets a different ref to model_to_load and only holds part of the state with respect to the loaded tensors. | ||
| # You could in theory return each worker's copy of the model and use .named_parameters() and .named_buffers(), but this appears to be more robust | ||
| # in that all you have to care about are the names of the layers in the state dict; as long as the logic that led to the creation of the state_dict is correct, this will also be correct. | ||
| for state_dict_modules in state_dict_modules_list: | ||
| for tensor_name in state_dict_modules.keys(): | ||
| splits = tensor_name.split(".") | ||
| module = model_to_load | ||
|  | ||
| for split in splits[:-1]: | ||
| module = getattr(module, split) | ||
|  | ||
| last_key = splits.pop() | ||
|          | ||
|  | ||
| # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop | ||
| del state_dict | ||
| tensor_ref = state_dict_modules[tensor_name] | ||
|  | ||
| setattr(module, last_key, tensor_ref) | ||
|  | ||
| del state_dict_modules_list | ||
| gc.collect() | ||
| else: | ||
| if len(args_list) > 1: | ||
| # For nice tqdm bars | ||
| args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") | ||
|  | ||
| for args in args_list: | ||
| _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = ( | ||
| load_shard_file(args) | ||
| ) | ||
|  | ||
| mismatched_keys += _mismatched_keys | ||
| error_msgs += _error_msgs | ||
|  | ||
| del state_dict_modules | ||
| gc.collect() | ||
|  | ||
| # Adjust offloaded weights name and save if needed | ||
| if disk_offload_index is not None and len(disk_offload_index) > 0: | ||
|  | ||
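
Stepping back from the line-level changes: the heart of the new path in `_load_pretrained_model` is a `multiprocessing.Pool` whose workers each call `load_shard_file`, with results consumed via `imap_unordered` and merged back into the original model afterwards. A stripped-down sketch of that pattern, using generic placeholder names rather than the actual transformers internals, might look like this:

```py
import os
from multiprocessing import Pool

from tqdm import tqdm


def load_shard(shard_path):
    # Placeholder for the per-shard work; in the PR this loads the shard's state dict
    # and returns the modules that were materialized from it.
    return {"shard": shard_path, "num_tensors": 0}


if __name__ == "__main__":
    # Hypothetical shard names; real checkpoints list their shard files in an index.
    shard_paths = [f"model-{i:05d}-of-00004.safetensors" for i in range(1, 5)]
    num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))

    results = []
    with Pool(processes=num_workers) as pool:
        with tqdm(total=len(shard_paths), desc="Loading checkpoint shards") as pbar:
            # imap_unordered yields each result as soon as its worker finishes; order
            # does not matter because every shard touches a disjoint set of weights.
            for result in pool.imap_unordered(load_shard, shard_paths):
                results.append(result)
                pbar.update(1)

    print(f"Merged results from {len(results)} shards")
```

The unordered variant keeps the progress bar responsive when shards take uneven amounts of time; the trade-off, as the comments in the diff note, is that workers receive pickled copies of the model, so the loaded tensors have to be written back into the original instance afterwards.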