From 8fb9b1877a9b0001bb33e3d34c85d663376a009f Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Tue, 18 Mar 2025 21:18:59 -0400 Subject: [PATCH 01/27] Get parallel loader working. Include tests. --- src/transformers/modeling_utils.py | 271 ++++++++++++++---- tests/utils/test_modeling_utils.py | 39 +-- .../test_modeling_utils_parallel_loading.py | 25 ++ 3 files changed, 254 insertions(+), 81 deletions(-) create mode 100644 tests/utils/test_modeling_utils_parallel_loading.py diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4158c82b4094..f856ee76ceff 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -874,6 +874,128 @@ def _load_state_dict_into_meta_model( return disk_offload_index, cpu_offload_index +def resolve_state_dict_modules(model_to_load, state_dict, expected_keys): + state_dict_modules = {} + + for tensor_name in state_dict.keys(): + if tensor_name not in expected_keys: + continue + + splits = tensor_name.split(".") + module = model_to_load + for split in splits: + try: + module = getattr(module, split) + except Exception as exception: + print(exception) + pass + + state_dict_modules[tensor_name] = module + + return state_dict_modules + + +# This function is in global scope so it's picklable for multiprocessing +def load_shard_file(args): + ( + state_dict, + shard_file, + disk_only_shard_files, + low_cpu_mem_usage, + is_quantized, + device_map, + hf_quantizer, + key_renaming_mapping, + weights_only, + model_to_load, + ignore_mismatched_sizes, + prefix, + loading_base_model_from_task_state_dict, + expected_keys, + reverse_key_renaming_mapping, + disk_offload_folder, + disk_offload_index, + cpu_offload_folder, + cpu_offload_index, + is_offloaded_safetensors, + keep_in_fp32_modules, + unexpected_keys, + device_mesh, + ) = args + # Skip the load for shards that only contain disk-offloaded weights + if shard_file in disk_only_shard_files: + return [], [], disk_offload_index, cpu_offload_index, {} + + map_location = "cpu" + if low_cpu_mem_usage: + if shard_file.endswith(".safetensors") and not is_quantized: + map_location = "meta" + elif ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + + # If shard_file is "", we use the existing state_dict instead of loading it + if shard_file != "": + state_dict = load_state_dict( + shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only + ) + + error_msgs = [] + mismatched_keys = [] + + # Fix the key names + state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. 
+ mismatched_keys += _find_mismatched_keys( + model_to_load, + state_dict, + ignore_mismatched_sizes, + prefix if loading_base_model_from_task_state_dict else "", + ) + + if low_cpu_mem_usage: + # Skip it with fsdp on ranks other than 0 + if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): + disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + shard_file, + expected_keys, + reverse_key_renaming_mapping, + device_map=device_map, + disk_offload_folder=disk_offload_folder, + disk_offload_index=disk_offload_index, + cpu_offload_folder=cpu_offload_folder, + cpu_offload_index=cpu_offload_index, + hf_quantizer=hf_quantizer, + is_safetensors=is_offloaded_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + unexpected_keys=unexpected_keys, + device_mesh=device_mesh, + ) + else: + assign_params = check_support_param_buffer_assignment(model_to_load, state_dict) + if is_deepspeed_zero3_enabled(): + error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params) + else: + model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) + + # We now figure out what in the state dict changed and store the module used for each layer, this will contain the device + # information we need in order to resolve all of the layers after multiprocessing which we write back to the original model_to_load meta model + state_dict_modules = resolve_state_dict_modules(model_to_load, state_dict, expected_keys) + + # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop + del state_dict + + return mismatched_keys, error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules + + def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: if variant is not None: path, name = weights_name.rsplit(".", 1) @@ -4810,9 +4932,6 @@ def _load_pretrained_model( cpu_offload_folder = tempfile.mkdtemp() cpu_offload_index = {} - # For nice tqdm bars - if checkpoint_files is not None and len(checkpoint_files) > 1: - checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards") # To be able to iterate, even if we don't use it if the state_dict is already provided elif state_dict is not None: checkpoint_files = [""] @@ -4827,73 +4946,99 @@ def _load_pretrained_model( expanded_device_map = expand_device_map(device_map, expected_keys) caching_allocator_warmup(model_to_load, expanded_device_map) - error_msgs = [] + from multiprocessing import Pool + + # Prepare arguments for multiprocessing + args_list = [ + ( + state_dict, + shard_file, + disk_only_shard_files, + low_cpu_mem_usage, + is_quantized, + device_map, + hf_quantizer, + key_renaming_mapping, + weights_only, + model_to_load, + ignore_mismatched_sizes, + prefix, + loading_base_model_from_task_state_dict, + expected_keys, + reverse_key_renaming_mapping, + disk_offload_folder, + disk_offload_index, + cpu_offload_folder, + cpu_offload_index, + is_offloaded_safetensors, + keep_in_fp32_modules, + unexpected_keys, + device_mesh, + ) + for shard_file in checkpoint_files + ] + mismatched_keys = [] - # Iterate on all the shards to load the weights - for shard_file in checkpoint_files: - # Skip the load for shards that only contain disk-offloaded weights - if shard_file in disk_only_shard_files: - continue + error_msgs = [] - map_location = "cpu" - if low_cpu_mem_usage: - if shard_file.endswith(".safetensors") and not is_quantized: - map_location = "meta" - elif ( - device_map 
is not None - and hf_quantizer is not None - and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] - ): - map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + # Use multiprocessing Pool for parallel execution, off by default + if json.loads(os.environ.get("ENABLE_PARALLEL_LOADING", "false")): + num_workers = json.loads(os.environ.get("PARALLEL_LOADING_WORKERS", "8")) + logger.info(f"Loading model weights in parallel with {num_workers} workers...") + state_dict_modules_list = [] + + with Pool(processes=num_workers) as pool: + # For nice tqdm bars + with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: + # NOTE order does not matter, layers that changed per shard are unique and can be reassigned to the orignal meta model + for result in pool.imap_unordered(load_shard_file, args_list): + _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = ( + result + ) - # If shard_file is "", we use the existing state_dict instead of loading it - if shard_file != "": - state_dict = load_state_dict( - shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only - ) + mismatched_keys += _mismatched_keys + error_msgs += _error_msgs - # Fix the key names - state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} + state_dict_modules_list.append(state_dict_modules) - # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not - # matching the weights in the model. - mismatched_keys += _find_mismatched_keys( - model_to_load, - state_dict, - ignore_mismatched_sizes, - prefix if loading_base_model_from_task_state_dict else "", - ) + pbar.update(1) - if low_cpu_mem_usage: - # Skip it with fsdp on ranks other than 0 - if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): - disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( - model_to_load, - state_dict, - shard_file, - expected_keys, - reverse_key_renaming_mapping, - device_map=device_map, - disk_offload_folder=disk_offload_folder, - disk_offload_index=disk_offload_index, - cpu_offload_folder=cpu_offload_folder, - cpu_offload_index=cpu_offload_index, - hf_quantizer=hf_quantizer, - is_safetensors=is_offloaded_safetensors, - keep_in_fp32_modules=keep_in_fp32_modules, - unexpected_keys=unexpected_keys, - device_mesh=device_mesh, - ) - else: - assign_params = check_support_param_buffer_assignment(model_to_load, state_dict) - if is_deepspeed_zero3_enabled(): - error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params) - else: - model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) + # We now update each layer of the meta model with the tensor module refs that were set to specific devices in the copy of the meta model for each worker + # We are transferring that state into the orginal ref (model_to_load) here + # This is required because model_to_load is pickled when using multiprocessing, which means the ref to model_to_load is different for each worker, so you only get some of the state with respect to the loaded tensors + # You could in theory return each worker's copy of the model and use .named_parameters(), and .named_buffers(), but this appears to be more robust + # in that all you have to care about are the names of the layers 
in the state dict, as long as the logic that lead to the creation of the state_dict is correct, this will also be correct + for state_dict_modules in state_dict_modules_list: + for tensor_name in state_dict_modules.keys(): + splits = tensor_name.split(".") + module = model_to_load + + for split in splits[:-1]: + module = getattr(module, split) + + last_key = splits.pop() - # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop - del state_dict + tensor_ref = state_dict_modules[tensor_name] + + setattr(module, last_key, tensor_ref) + + del state_dict_modules_list + gc.collect() + else: + if len(args_list) > 1: + # For nice tqdm bars + args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") + + for args in args_list: + _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = ( + load_shard_file(args) + ) + + mismatched_keys += _mismatched_keys + error_msgs += _error_msgs + + del state_dict_modules + gc.collect() # Adjust offloaded weights name and save if needed if disk_offload_index is not None and len(disk_offload_index) > 0: diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 71a400579f63..887e5cdda433 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -298,6 +298,27 @@ def test_local_files_only(self): hub.TRANSFORMERS_CACHE = transformers_cache +# Need to be serializable, which means they cannot be in a test class method +class TestGammaBetaNorm(torch.nn.Module): + def __init__(self): + super().__init__() + self.gamma = torch.nn.Parameter(torch.ones(1)) + self.beta = torch.nn.Parameter(torch.zeros(1)) + + def forward(self): + return self.gamma.sum() + self.beta.sum() + + +class TestModelGammaBeta(PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.LayerNorm = TestGammaBetaNorm() + self.post_init() + + def forward(self): + return self.LayerNorm() + + if is_flax_available(): from transformers import FlaxBertModel @@ -1641,24 +1662,6 @@ def test_model_from_pretrained_from_mlx(self): torch.testing.assert_close(outputs_from_saved["logits"], outputs["logits"]) def test_warning_for_beta_gamma_parameters(self): - class TestGammaBetaNorm(torch.nn.Module): - def __init__(self): - super().__init__() - self.gamma = torch.nn.Parameter(torch.ones(1)) - self.beta = torch.nn.Parameter(torch.zeros(1)) - - def forward(self): - return self.gamma.sum() + self.beta.sum() - - class TestModelGammaBeta(PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.LayerNorm = TestGammaBetaNorm() - self.post_init() - - def forward(self): - return self.LayerNorm() - logger = logging.get_logger("transformers.modeling_utils") config = PretrainedConfig() warning_msg_gamma = "`LayerNorm.gamma` -> `LayerNorm.weight`" diff --git a/tests/utils/test_modeling_utils_parallel_loading.py b/tests/utils/test_modeling_utils_parallel_loading.py new file mode 100644 index 000000000000..533747a6f826 --- /dev/null +++ b/tests/utils/test_modeling_utils_parallel_loading.py @@ -0,0 +1,25 @@ +import multiprocessing +import os + +from .test_modeling_utils import ModelUtilsTest + + +original_setUp = ModelUtilsTest.setUp + + +# We're monkey patching the original tests, as we want to run them, but now with the parallel loader enabled +def patched_setUp(self): + # Call the original setUp first + original_setUp(self) + + # Set the env variable to enable parallel loading + os.environ.setdefault("ENABLE_PARALLEL_LOADING", 
"true") + # Set multiprocessing to spawn, which is required due to torch contraints + multiprocessing.set_start_method("spawn", force=True) + + +try: + # Monkey patch the setUp method + ModelUtilsTest.setUp = patched_setUp +finally: + ModelUtilsTest.setUp = original_setUp From 27f36f2432425ba8ae4e742a52e0aaa50de262eb Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Tue, 18 Mar 2025 21:28:55 -0400 Subject: [PATCH 02/27] Update the tests for parallel loading --- .../test_modeling_utils_parallel_loading.py | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/tests/utils/test_modeling_utils_parallel_loading.py b/tests/utils/test_modeling_utils_parallel_loading.py index 533747a6f826..3f035c56b889 100644 --- a/tests/utils/test_modeling_utils_parallel_loading.py +++ b/tests/utils/test_modeling_utils_parallel_loading.py @@ -1,25 +1,12 @@ import multiprocessing import os -from .test_modeling_utils import ModelUtilsTest +# Set the env variable to enable parallel loading +os.environ["ENABLE_PARALLEL_LOADING"] = "true" -original_setUp = ModelUtilsTest.setUp +# Set multiprocessing to spawn, which is required due to torch contraints +multiprocessing.set_start_method("spawn", force=True) - -# We're monkey patching the original tests, as we want to run them, but now with the parallel loader enabled -def patched_setUp(self): - # Call the original setUp first - original_setUp(self) - - # Set the env variable to enable parallel loading - os.environ.setdefault("ENABLE_PARALLEL_LOADING", "true") - # Set multiprocessing to spawn, which is required due to torch contraints - multiprocessing.set_start_method("spawn", force=True) - - -try: - # Monkey patch the setUp method - ModelUtilsTest.setUp = patched_setUp -finally: - ModelUtilsTest.setUp = original_setUp +# Declare the normal model_utils.py test as a sideffect of importing the module +from .test_modeling_utils import ModelUtilsTest # noqa \ No newline at end of file From e7c3ea52ad8d318ec3dac58566716a20bdf3a2f5 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Wed, 19 Mar 2025 13:12:34 -0400 Subject: [PATCH 03/27] Rename env variables. 
--- src/transformers/modeling_utils.py | 2 +- tests/utils/test_modeling_utils_parallel_loading.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f856ee76ceff..214920cf7c87 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4982,7 +4982,7 @@ def _load_pretrained_model( error_msgs = [] # Use multiprocessing Pool for parallel execution, off by default - if json.loads(os.environ.get("ENABLE_PARALLEL_LOADING", "false")): + if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")): num_workers = json.loads(os.environ.get("PARALLEL_LOADING_WORKERS", "8")) logger.info(f"Loading model weights in parallel with {num_workers} workers...") state_dict_modules_list = [] diff --git a/tests/utils/test_modeling_utils_parallel_loading.py b/tests/utils/test_modeling_utils_parallel_loading.py index 3f035c56b889..145456413f00 100644 --- a/tests/utils/test_modeling_utils_parallel_loading.py +++ b/tests/utils/test_modeling_utils_parallel_loading.py @@ -3,10 +3,10 @@ # Set the env variable to enable parallel loading -os.environ["ENABLE_PARALLEL_LOADING"] = "true" +os.environ["HF_HF_ENABLE_PARALLEL_LOADING"] = "true" # Set multiprocessing to spawn, which is required due to torch contraints multiprocessing.set_start_method("spawn", force=True) # Declare the normal model_utils.py test as a sideffect of importing the module -from .test_modeling_utils import ModelUtilsTest # noqa \ No newline at end of file +from .test_modeling_utils import ModelUtilsTest # noqa From 7599fe25d3d7b16366740d647736e33555f5e8fc Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Wed, 19 Mar 2025 14:38:50 -0400 Subject: [PATCH 04/27] Add docs for parallel model weight loading. --- docs/source/en/_toctree.yml | 5 ++ .../en/reference/environment_variables.md | 66 +++++++++++++++++++ 2 files changed, 71 insertions(+) create mode 100644 docs/source/en/reference/environment_variables.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 79f8eb3d490d..ad1704d452c5 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1057,4 +1057,9 @@ - local: internal/time_series_utils title: Utilities for Time Series title: Internal helpers + - sections: + - local: reference/environment_variables + title: Environment Variables + title: Reference title: API + diff --git a/docs/source/en/reference/environment_variables.md b/docs/source/en/reference/environment_variables.md new file mode 100644 index 000000000000..b4b0a850a1ae --- /dev/null +++ b/docs/source/en/reference/environment_variables.md @@ -0,0 +1,66 @@ + + +# Environment Variables + +## HF_ENABLE_PARALLEL_LOADING + +Enables the loading of torch and safetensor based weights to be loaded in parallel. Can decrease the time to load large models significantly, often times producing speed ups of greater than 50%. + +Can be set to a string equal to `"false"` or `"true"`. e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"` + +e.g. `facebook/opt-30b` on an AWS EC2 g4dn.metal instance can be made to load in ~20s with this enabled vs ~45s without it. + +Profile before committing to using this environment variable, this will not produce speed ups for smaller models. 
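A quick way to profile is to time the pipeline call itself and compare a run with the flag enabled against one without it. The model name and instance are only examples, and the OS file cache should be cleared between runs since a warm cache skews the comparison:

```py
import os
import time

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"  # set to "false" for the baseline run

import multiprocessing
from transformers import pipeline

if __name__ == "__main__":
    # NOTE if a model loads on CPU this is not required
    multiprocessing.set_start_method("spawn", force=True)

    start = time.perf_counter()
    model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
    print(f"load took {time.perf_counter() - start:.1f}s")
```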
+ +NOTE, if you are not loading a model onto specifically the CPU, you must set `multiprocessing` to use the `spawn` start method like so: + +```py +import os + +os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" + +import multiprocessing +from transformers import pipeline + +if __name__ == "__main__": + # NOTE if a model loads on CPU this is not required + multiprocessing.set_start_method("spawn", force=True) + + model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") +``` + +If loading onto a cuda device, the code will crash if multiprocessing.set_start_method("spawn", force=True) is not set. + +## HF_PARALLEL_LOADING_WORKERS + +Determines how many child processes should be used when parallel loading is enabled. Default is `8`. Tune as you see fit. + +```py +import os + +os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" +os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4" + +import multiprocessing +from transformers import pipeline + +if __name__ == "__main__": + # NOTE if a model loads on CPU this is not required + multiprocessing.set_start_method("spawn", force=True) + + model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") +``` From 065e1022e2b572b6638043efb2fa408d4cb49298 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Wed, 19 Mar 2025 14:39:43 -0400 Subject: [PATCH 05/27] Touch up parallel model loading docs. --- docs/source/en/reference/environment_variables.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/reference/environment_variables.md b/docs/source/en/reference/environment_variables.md index b4b0a850a1ae..479046f7f09a 100644 --- a/docs/source/en/reference/environment_variables.md +++ b/docs/source/en/reference/environment_variables.md @@ -20,7 +20,7 @@ rendered properly in your Markdown viewer. Enables the loading of torch and safetensor based weights to be loaded in parallel. Can decrease the time to load large models significantly, often times producing speed ups of greater than 50%. -Can be set to a string equal to `"false"` or `"true"`. e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"` +Can be set to a string equal to `"false"` or `"true"`. e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`. By default it is disabled. e.g. `facebook/opt-30b` on an AWS EC2 g4dn.metal instance can be made to load in ~20s with this enabled vs ~45s without it. From d31594a769683bf48d92e1ec1d1982d0131b187d Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Wed, 19 Mar 2025 14:40:36 -0400 Subject: [PATCH 06/27] Touch up parallel model loading docs again. --- docs/source/en/reference/environment_variables.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/reference/environment_variables.md b/docs/source/en/reference/environment_variables.md index 479046f7f09a..e246a44e9ee0 100644 --- a/docs/source/en/reference/environment_variables.md +++ b/docs/source/en/reference/environment_variables.md @@ -18,9 +18,9 @@ rendered properly in your Markdown viewer. ## HF_ENABLE_PARALLEL_LOADING -Enables the loading of torch and safetensor based weights to be loaded in parallel. Can decrease the time to load large models significantly, often times producing speed ups of greater than 50%. +By default this is disabled. Enables the loading of torch and safetensor based weights to be loaded in parallel. Can decrease the time to load large models significantly, often times producing speed ups of greater than 50%. -Can be set to a string equal to `"false"` or `"true"`. e.g. 
`os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`. By default it is disabled. +Can be set to a string equal to `"false"` or `"true"`. e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`. e.g. `facebook/opt-30b` on an AWS EC2 g4dn.metal instance can be made to load in ~20s with this enabled vs ~45s without it. From 33b3e0f04b72319549e6c74790caa5a37f8b39d4 Mon Sep 17 00:00:00 2001 From: Aaron V Date: Wed, 19 Mar 2025 14:43:37 -0400 Subject: [PATCH 07/27] Edit comment in test_modeling_utils_parallel_loading.py --- tests/utils/test_modeling_utils_parallel_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_modeling_utils_parallel_loading.py b/tests/utils/test_modeling_utils_parallel_loading.py index 145456413f00..a1d2f6dbd7a6 100644 --- a/tests/utils/test_modeling_utils_parallel_loading.py +++ b/tests/utils/test_modeling_utils_parallel_loading.py @@ -5,7 +5,7 @@ # Set the env variable to enable parallel loading os.environ["HF_HF_ENABLE_PARALLEL_LOADING"] = "true" -# Set multiprocessing to spawn, which is required due to torch contraints +# Set multiprocessing to spawn, which is required due to cuda constraints multiprocessing.set_start_method("spawn", force=True) # Declare the normal model_utils.py test as a sideffect of importing the module From 0e22c047f18bc3f47b1c32969bc89c08f07c7d33 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Wed, 19 Mar 2025 15:05:37 -0400 Subject: [PATCH 08/27] Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modeling_utils.py --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 214920cf7c87..ebe768184493 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4983,7 +4983,7 @@ def _load_pretrained_model( # Use multiprocessing Pool for parallel execution, off by default if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")): - num_workers = json.loads(os.environ.get("PARALLEL_LOADING_WORKERS", "8")) + num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) logger.info(f"Loading model weights in parallel with {num_workers} workers...") state_dict_modules_list = [] From 904bdaf67a3265001314350c3895fbdb29654566 Mon Sep 17 00:00:00 2001 From: Aaron V Date: Thu, 20 Mar 2025 22:10:09 -0400 Subject: [PATCH 09/27] Correct times for parallelized loading, previous times were for a "hot" filesystem --- docs/source/en/reference/environment_variables.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/reference/environment_variables.md b/docs/source/en/reference/environment_variables.md index e246a44e9ee0..0929b43e3525 100644 --- a/docs/source/en/reference/environment_variables.md +++ b/docs/source/en/reference/environment_variables.md @@ -18,11 +18,11 @@ rendered properly in your Markdown viewer. ## HF_ENABLE_PARALLEL_LOADING -By default this is disabled. Enables the loading of torch and safetensor based weights to be loaded in parallel. Can decrease the time to load large models significantly, often times producing speed ups of greater than 50%. +By default this is disabled. Enables the loading of torch and safetensor based weights to be loaded in parallel. Can decrease the time to load large models significantly, often times producing speed ups around ~50%. Can be set to a string equal to `"false"` or `"true"`. e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`. -e.g. 
`facebook/opt-30b` on an AWS EC2 g4dn.metal instance can be made to load in ~20s with this enabled vs ~45s without it. +e.g. `facebook/opt-30b` on an AWS EC2 g4dn.metal instance can be made to load in ~30s with this enabled vs ~55s without it. Profile before committing to using this environment variable, this will not produce speed ups for smaller models. From 7e37ba4be2749a29dd308f384d4cee65f6479a2c Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Mon, 24 Mar 2025 17:41:59 -0400 Subject: [PATCH 10/27] Update parallel model loading so the spawn method is encapsulated. DRY up the code by leveraging get_submodule. --- src/transformers/modeling_utils.py | 73 ++++++++++--------- .../test_modeling_utils_parallel_loading.py | 5 +- 2 files changed, 41 insertions(+), 37 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index ebe768184493..35d0ce295515 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -22,6 +22,7 @@ import itertools import json import math +import multiprocessing import os import re import shutil @@ -4946,8 +4947,6 @@ def _load_pretrained_model( expanded_device_map = expand_device_map(device_map, expected_keys) caching_allocator_warmup(model_to_load, expanded_device_map) - from multiprocessing import Pool - # Prepare arguments for multiprocessing args_list = [ ( @@ -4983,47 +4982,55 @@ def _load_pretrained_model( # Use multiprocessing Pool for parallel execution, off by default if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")): - num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) - logger.info(f"Loading model weights in parallel with {num_workers} workers...") - state_dict_modules_list = [] + original_start_method = multiprocessing.get_start_method(allow_none=True) - with Pool(processes=num_workers) as pool: - # For nice tqdm bars - with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: - # NOTE order does not matter, layers that changed per shard are unique and can be reassigned to the orignal meta model - for result in pool.imap_unordered(load_shard_file, args_list): - _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = ( - result - ) + try: + # CUDA requires the start method to be spawn, fork creates multiple copies of the cuda runtime, which throws + multiprocessing.set_start_method("spawn", force=True) - mismatched_keys += _mismatched_keys - error_msgs += _error_msgs + num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) - state_dict_modules_list.append(state_dict_modules) + # Do not spawn anymore workers than you need + num_workers = min(len(args_list), num_workers) - pbar.update(1) + logger.info(f"Loading model weights in parallel with {num_workers} workers...") + state_dict_modules_list = [] - # We now update each layer of the meta model with the tensor module refs that were set to specific devices in the copy of the meta model for each worker - # We are transferring that state into the orginal ref (model_to_load) here - # This is required because model_to_load is pickled when using multiprocessing, which means the ref to model_to_load is different for each worker, so you only get some of the state with respect to the loaded tensors - # You could in theory return each worker's copy of the model and use .named_parameters(), and .named_buffers(), but this appears to be more robust - # in that all you have to care about are the names of the layers in the 
state dict, as long as the logic that lead to the creation of the state_dict is correct, this will also be correct - for state_dict_modules in state_dict_modules_list: - for tensor_name in state_dict_modules.keys(): - splits = tensor_name.split(".") - module = model_to_load + with multiprocessing.Pool(processes=num_workers) as pool: + # For nice tqdm bars + with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: + # NOTE order does not matter, layers that changed per shard are unique and can be reassigned to the orignal meta model + for result in pool.imap_unordered(load_shard_file, args_list): + _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = ( + result + ) + + mismatched_keys += _mismatched_keys + error_msgs += _error_msgs - for split in splits[:-1]: - module = getattr(module, split) + state_dict_modules_list.append(state_dict_modules) - last_key = splits.pop() + pbar.update(1) - tensor_ref = state_dict_modules[tensor_name] + # We now update each layer of the meta model with the tensor module refs that were set to specific devices in the copy of the meta model for each worker + # We are transferring that state into the orginal ref (model_to_load) here + # This is required because model_to_load is pickled when using multiprocessing, which means the ref to model_to_load is different for each worker, so you only get some of the state with respect to the loaded tensors + # You could in theory return each worker's copy of the model and use .named_parameters(), and .named_buffers(), but this appears to be more robust + # in that all you have to care about are the names of the layers in the state dict, as long as the logic that lead to the creation of the state_dict is correct, this will also be correct + for state_dict_modules in state_dict_modules_list: + for full_name, param in state_dict_modules.items(): + *module_path, attr_name = full_name.split(".") + module_path = '.'.join(module_path) + module = model_to_load.get_submodule(module_path) + setattr(module, attr_name, param) + + del state_dict_modules_list + gc.collect() + finally: + # Restore the start method to prevent side effects for other code that may be running + multiprocessing.set_start_method(original_start_method, force=True) - setattr(module, last_key, tensor_ref) - del state_dict_modules_list - gc.collect() else: if len(args_list) > 1: # For nice tqdm bars diff --git a/tests/utils/test_modeling_utils_parallel_loading.py b/tests/utils/test_modeling_utils_parallel_loading.py index a1d2f6dbd7a6..eda625230448 100644 --- a/tests/utils/test_modeling_utils_parallel_loading.py +++ b/tests/utils/test_modeling_utils_parallel_loading.py @@ -1,12 +1,9 @@ -import multiprocessing import os # Set the env variable to enable parallel loading -os.environ["HF_HF_ENABLE_PARALLEL_LOADING"] = "true" +os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" -# Set multiprocessing to spawn, which is required due to cuda constraints -multiprocessing.set_start_method("spawn", force=True) # Declare the normal model_utils.py test as a sideffect of importing the module from .test_modeling_utils import ModelUtilsTest # noqa From a203f6a13c596ca5c9f51c84343848f11f3ff070 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Mon, 24 Mar 2025 17:45:30 -0400 Subject: [PATCH 11/27] Update docs on model loading parallelism so that details on setting the multiprocessing start method are removed, now that the package handles this step internally. 
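Users no longer have to call `multiprocessing.set_start_method` themselves because the loader now wraps the worker pool in a save/restore of the start method (added in the previous commit). A simplified sketch of that pattern, with all of the shard-loading details stripped out:

```py
import multiprocessing


def run_with_spawned_pool(worker, args_list, num_workers):
    # CUDA state cannot be forked safely, so force "spawn" for the pool,
    # then put back whatever start method the caller had before.
    original = multiprocessing.get_start_method(allow_none=True)
    try:
        multiprocessing.set_start_method("spawn", force=True)
        with multiprocessing.Pool(processes=num_workers) as pool:
            return list(pool.imap_unordered(worker, args_list))
    finally:
        multiprocessing.set_start_method(original, force=True)
```

With that handled internally, the docs only need to show setting `HF_ENABLE_PARALLEL_LOADING` before loading the model.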
--- .../en/reference/environment_variables.md | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/docs/source/en/reference/environment_variables.md b/docs/source/en/reference/environment_variables.md index 0929b43e3525..fc09a98b5aec 100644 --- a/docs/source/en/reference/environment_variables.md +++ b/docs/source/en/reference/environment_variables.md @@ -26,8 +26,6 @@ e.g. `facebook/opt-30b` on an AWS EC2 g4dn.metal instance can be made to load in Profile before committing to using this environment variable, this will not produce speed ups for smaller models. -NOTE, if you are not loading a model onto specifically the CPU, you must set `multiprocessing` to use the `spawn` start method like so: - ```py import os @@ -36,18 +34,18 @@ os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" import multiprocessing from transformers import pipeline -if __name__ == "__main__": - # NOTE if a model loads on CPU this is not required - multiprocessing.set_start_method("spawn", force=True) - - model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") +model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") ``` -If loading onto a cuda device, the code will crash if multiprocessing.set_start_method("spawn", force=True) is not set. - ## HF_PARALLEL_LOADING_WORKERS -Determines how many child processes should be used when parallel loading is enabled. Default is `8`. Tune as you see fit. +Determines how many child processes should be used when parallel loading is enabled. Default is `8`. + +If the number of files that are being loaded is less than the number of child processes specified, the number that is actually spawned will be equal to the number of files. + +e.g. If you specify 8 workers, and there are only 2 files, only 2 workers will be spawned. + +Tune as you see fit. ```py import os @@ -58,9 +56,5 @@ os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4" import multiprocessing from transformers import pipeline -if __name__ == "__main__": - # NOTE if a model loads on CPU this is not required - multiprocessing.set_start_method("spawn", force=True) - - model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") +model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") ``` From 14e9eefb1cc51ca75aba9bbc2d1dbe24535ed193 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Mon, 24 Mar 2025 17:49:07 -0400 Subject: [PATCH 12/27] Fix style on model loading parallelism changes. 
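The loop touched below is the write-back step: each fully qualified key from a worker's state dict is split into a submodule path plus an attribute name, resolved with `get_submodule`, and reassigned on the original model. A self-contained toy version of that resolution (the `Toy` module and the zero tensor are invented purely for illustration):

```py
import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(4, 4))


model = Toy()
full_name = "block.0.weight"  # a dotted key as it would appear in a state dict

*module_path, attr_name = full_name.split(".")
module = model.get_submodule(".".join(module_path))  # the nn.Linear inside `block`
setattr(module, attr_name, nn.Parameter(torch.zeros(4, 4)))

print(model.block[0].weight.sum())  # tensor(0., grad_fn=<SumBackward0>)
```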
--- src/transformers/modeling_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 35d0ce295515..6ba16f5777f5 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -5001,9 +5001,13 @@ def _load_pretrained_model( with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: # NOTE order does not matter, layers that changed per shard are unique and can be reassigned to the orignal meta model for result in pool.imap_unordered(load_shard_file, args_list): - _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = ( - result - ) + ( + _mismatched_keys, + _error_msgs, + disk_offload_index, + cpu_offload_index, + state_dict_modules, + ) = result mismatched_keys += _mismatched_keys error_msgs += _error_msgs @@ -5020,7 +5024,7 @@ def _load_pretrained_model( for state_dict_modules in state_dict_modules_list: for full_name, param in state_dict_modules.items(): *module_path, attr_name = full_name.split(".") - module_path = '.'.join(module_path) + module_path = ".".join(module_path) module = model_to_load.get_submodule(module_path) setattr(module, attr_name, param) @@ -5030,7 +5034,6 @@ def _load_pretrained_model( # Restore the start method to prevent side effects for other code that may be running multiprocessing.set_start_method(original_start_method, force=True) - else: if len(args_list) > 1: # For nice tqdm bars From d5637e8bcadc9f29223e7ac4f653c4f4653dd676 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Tue, 8 Apr 2025 17:44:35 -0400 Subject: [PATCH 13/27] Merge latest version of master's modeling_utils. --- src/transformers/modeling_utils.py | 118 ++++++++++++----------------- 1 file changed, 50 insertions(+), 68 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 412c3bb17d32..47a859e639da 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -830,23 +830,18 @@ def resolve_state_dict_modules(model_to_load, state_dict, expected_keys): return state_dict_modules - # This function is in global scope so it's picklable for multiprocessing def load_shard_file(args): ( - state_dict, shard_file, disk_only_shard_files, - low_cpu_mem_usage, + is_hqq_or_bnb, is_quantized, device_map, hf_quantizer, key_renaming_mapping, weights_only, model_to_load, - ignore_mismatched_sizes, - prefix, - loading_base_model_from_task_state_dict, expected_keys, reverse_key_renaming_mapping, disk_offload_folder, @@ -854,25 +849,32 @@ def load_shard_file(args): cpu_offload_folder, cpu_offload_index, is_offloaded_safetensors, - keep_in_fp32_modules, + keep_in_fp32_regex, unexpected_keys, - device_mesh, - ) = args + device_mesh + ) = args + # Skip the load for shards that only contain disk-offloaded weights if shard_file in disk_only_shard_files: - return [], [], disk_offload_index, cpu_offload_index, {} + return [], disk_offload_index, cpu_offload_index, {} map_location = "cpu" - if low_cpu_mem_usage: - if shard_file.endswith(".safetensors") and not is_quantized: - map_location = "meta" - elif ( - device_map is not None - and hf_quantizer is not None - and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO - and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] - ): - map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) + if ( + 
shard_file.endswith(".safetensors") + and not is_hqq_or_bnb + and not (is_deepspeed_zero3_enabled() and not is_quantized) + ): + map_location = "meta" + elif ( + device_map is not None + and hf_quantizer is not None + and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO + and ( + hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"] + or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig) + ) + ): + map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0]) # If shard_file is "", we use the existing state_dict instead of loading it if shard_file != "": @@ -880,48 +882,32 @@ def load_shard_file(args): shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only ) - error_msgs = [] - mismatched_keys = [] - # Fix the key names state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping} - # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not - # matching the weights in the model. - mismatched_keys += _find_mismatched_keys( - model_to_load, - state_dict, - ignore_mismatched_sizes, - prefix if loading_base_model_from_task_state_dict else "", - ) - - if low_cpu_mem_usage: - # Skip it with fsdp on ranks other than 0 - if not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): - disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( - model_to_load, - state_dict, - shard_file, - expected_keys, - reverse_key_renaming_mapping, - device_map=device_map, - disk_offload_folder=disk_offload_folder, - disk_offload_index=disk_offload_index, - cpu_offload_folder=cpu_offload_folder, - cpu_offload_index=cpu_offload_index, - hf_quantizer=hf_quantizer, - is_safetensors=is_offloaded_safetensors, - keep_in_fp32_modules=keep_in_fp32_modules, - unexpected_keys=unexpected_keys, - device_mesh=device_mesh, - ) - else: - assign_params = check_support_param_buffer_assignment(model_to_load, state_dict) - if is_deepspeed_zero3_enabled(): - error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_params) - else: - model_to_load.load_state_dict(state_dict, strict=False, assign=assign_params) + error_msgs = [] + if is_deepspeed_zero3_enabled() and not is_quantized: + error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict) + # Skip it with fsdp on ranks other than 0 + elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized): + disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + shard_file, + expected_keys, + reverse_key_renaming_mapping, + device_map=device_map, + disk_offload_folder=disk_offload_folder, + disk_offload_index=disk_offload_index, + cpu_offload_folder=cpu_offload_folder, + cpu_offload_index=cpu_offload_index, + hf_quantizer=hf_quantizer, + is_safetensors=is_offloaded_safetensors, + keep_in_fp32_regex=keep_in_fp32_regex, + unexpected_keys=unexpected_keys, + device_mesh=device_mesh, + ) # We now figure out what in the state dict changed and store the module used for each layer, this will contain the device # information we need in order to resolve all of the layers after multiprocessing which we write back to the original model_to_load meta model state_dict_modules = resolve_state_dict_modules(model_to_load, state_dict, expected_keys) @@ -929,8 +915,7 @@ def load_shard_file(args): # force memory release if loading multiple 
shards, to avoid having 2 state dicts in memory in next loop del state_dict - return mismatched_keys, error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules - + return error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: if variant is not None: @@ -4857,6 +4842,9 @@ def _load_pretrained_model( cpu_offload_folder = tempfile.mkdtemp() cpu_offload_index = {} + # For nice tqdm bars + if checkpoint_files is not None and len(checkpoint_files) > 1: + checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards") # To be able to iterate, even if we don't use it if the state_dict is already provided elif state_dict is not None: checkpoint_files = [""] @@ -4877,7 +4865,6 @@ def _load_pretrained_model( state_dict, shard_file, disk_only_shard_files, - low_cpu_mem_usage, is_quantized, device_map, hf_quantizer, @@ -4894,7 +4881,6 @@ def _load_pretrained_model( cpu_offload_folder, cpu_offload_index, is_offloaded_safetensors, - keep_in_fp32_modules, unexpected_keys, device_mesh, ) @@ -4926,14 +4912,12 @@ def _load_pretrained_model( # NOTE order does not matter, layers that changed per shard are unique and can be reassigned to the orignal meta model for result in pool.imap_unordered(load_shard_file, args_list): ( - _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules, ) = result - mismatched_keys += _mismatched_keys error_msgs += _error_msgs state_dict_modules_list.append(state_dict_modules) @@ -4964,11 +4948,9 @@ def _load_pretrained_model( args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") for args in args_list: - _mismatched_keys, _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = ( + _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = ( load_shard_file(args) ) - - mismatched_keys += _mismatched_keys error_msgs += _error_msgs del state_dict_modules @@ -5986,4 +5968,4 @@ def valid_keys(self) -> List[str]: # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones -ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() +ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() \ No newline at end of file From e0d37bb0ded144c5db38fe30b9a747c0239f9059 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Tue, 8 Apr 2025 17:48:22 -0400 Subject: [PATCH 14/27] Removed unused variable. --- src/transformers/modeling_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 47a859e639da..e6762848a2ac 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4887,7 +4887,6 @@ def _load_pretrained_model( for shard_file in checkpoint_files ] - mismatched_keys = [] error_msgs = [] # Use multiprocessing Pool for parallel execution, off by default From 9b4165c10a136e28898f71a7f3736fe218e9aab7 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Tue, 8 Apr 2025 18:30:26 -0400 Subject: [PATCH 15/27] Fix argument packing for the parallel loader. 
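The failure mode fixed here is purely positional: `load_shard_file` unpacks a single tuple, so the tuple built in `_load_pretrained_model` has to list its fields in exactly the order the worker unpacks them. A reduced sketch of that convention, keeping only three of the real fields and a worker that simply echoes them back:

```py
import multiprocessing


# Module-level so it can be pickled and shipped to spawned workers.
def load_shard_file(args):
    shard_file, device_map, weights_only = args  # must mirror the packing order below
    return f"{shard_file}: device_map={device_map}, weights_only={weights_only}"


if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)
    args_list = [(f"model-0000{i}-of-00003.safetensors", "cpu", True) for i in range(1, 4)]
    with multiprocessing.Pool(processes=2) as pool:
        for line in pool.imap_unordered(load_shard_file, args_list):
            print(line)
```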
--- src/transformers/modeling_utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e6762848a2ac..0d831d73ca37 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -4862,18 +4862,15 @@ def _load_pretrained_model( # Prepare arguments for multiprocessing args_list = [ ( - state_dict, shard_file, disk_only_shard_files, + is_hqq_or_bnb, is_quantized, device_map, hf_quantizer, key_renaming_mapping, weights_only, model_to_load, - ignore_mismatched_sizes, - prefix, - loading_base_model_from_task_state_dict, expected_keys, reverse_key_renaming_mapping, disk_offload_folder, @@ -4881,8 +4878,9 @@ def _load_pretrained_model( cpu_offload_folder, cpu_offload_index, is_offloaded_safetensors, + keep_in_fp32_regex, unexpected_keys, - device_mesh, + device_mesh ) for shard_file in checkpoint_files ] From 1085461d9d0cabc7e1ac824d5343e29a57bb2d2b Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Tue, 8 Apr 2025 18:42:41 -0400 Subject: [PATCH 16/27] Fix state dict being undefined in the parallel model loader. --- src/transformers/modeling_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0d831d73ca37..589fd67f9b14 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -834,6 +834,7 @@ def resolve_state_dict_modules(model_to_load, state_dict, expected_keys): def load_shard_file(args): ( shard_file, + state_dict, disk_only_shard_files, is_hqq_or_bnb, is_quantized, @@ -4863,6 +4864,7 @@ def _load_pretrained_model( args_list = [ ( shard_file, + state_dict, disk_only_shard_files, is_hqq_or_bnb, is_quantized, From 7ae3db6148caf419e42665cb379b8d8d359b6037 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Tue, 29 Apr 2025 15:58:33 -0400 Subject: [PATCH 17/27] Rename variables used in parallel model loading for clarity. Use get_module_from_name(). 
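The refs collected per worker are tensors that have already been materialized off the meta device; assigning them back onto the original model is what finally gives it real storage. A toy illustration of that meta-to-real transition, separate from the loader code (the zero tensor is a stand-in for data read from a shard, and the device context manager needs torch >= 2.0):

```py
import torch
from torch import nn

# Parameters created under the meta device are shape/dtype placeholders with no storage.
with torch.device("meta"):
    layer = nn.Linear(4, 4)
print(layer.weight.device)  # meta

# Once a shard has been read, the placeholder is replaced by a real tensor.
loaded = torch.zeros(4, 4)  # stand-in for a tensor loaded from a checkpoint shard
layer.weight = nn.Parameter(loaded)
print(layer.weight.device)  # cpu
```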
--- src/transformers/modeling_utils.py | 52 ++++++++++++------------------ 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e3ace6db745f..e9c4e116bf43 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -836,26 +836,18 @@ def _load_state_dict_into_meta_model( return disk_offload_index, cpu_offload_index +def resolve_state_dict_tensor_refs(model_to_load, state_dict, expected_keys): + state_dict_tensor_refs = {} -def resolve_state_dict_modules(model_to_load, state_dict, expected_keys): - state_dict_modules = {} - - for tensor_name in state_dict.keys(): - if tensor_name not in expected_keys: + for full_tensor_name in state_dict.keys(): + if full_tensor_name not in expected_keys: continue - splits = tensor_name.split(".") - module = model_to_load - for split in splits: - try: - module = getattr(module, split) - except Exception as exception: - print(exception) - pass + module, tensor_name = get_module_from_name(model_to_load, full_tensor_name) - state_dict_modules[tensor_name] = module + state_dict_tensor_refs[full_tensor_name] = getattr(module, tensor_name) - return state_dict_modules + return state_dict_tensor_refs # This function is in global scope so it's picklable for multiprocessing def load_shard_file(args): @@ -938,12 +930,12 @@ def load_shard_file(args): ) # We now figure out what in the state dict changed and store the module used for each layer, this will contain the device # information we need in order to resolve all of the layers after multiprocessing which we write back to the original model_to_load meta model - state_dict_modules = resolve_state_dict_modules(model_to_load, state_dict, expected_keys) + state_dict_tensor_refs = resolve_state_dict_tensor_refs(model_to_load, state_dict, expected_keys) # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop del state_dict - return error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules + return error_msgs, disk_offload_index, cpu_offload_index, state_dict_tensor_refs def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: if variant is not None: @@ -5050,7 +5042,7 @@ def _load_pretrained_model( error_msgs = [] # Use multiprocessing Pool for parallel execution, off by default - if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")): + if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")) and not is_deepspeed_zero3_enabled(): original_start_method = multiprocessing.get_start_method(allow_none=True) try: @@ -5063,7 +5055,7 @@ def _load_pretrained_model( num_workers = min(len(args_list), num_workers) logger.info(f"Loading model weights in parallel with {num_workers} workers...") - state_dict_modules_list = [] + state_dict_tensor_refs_list = [] with multiprocessing.Pool(processes=num_workers) as pool: # For nice tqdm bars @@ -5074,12 +5066,12 @@ def _load_pretrained_model( _error_msgs, disk_offload_index, cpu_offload_index, - state_dict_modules, + state_dict_tensor_refs, ) = result error_msgs += _error_msgs - state_dict_modules_list.append(state_dict_modules) + state_dict_tensor_refs_list.append(state_dict_tensor_refs) pbar.update(1) @@ -5088,14 +5080,12 @@ def _load_pretrained_model( # This is required because model_to_load is pickled when using multiprocessing, which means the ref to model_to_load is different for each worker, so you only get some of the state with respect to the loaded tensors # 
You could in theory return each worker's copy of the model and use .named_parameters(), and .named_buffers(), but this appears to be more robust # in that all you have to care about are the names of the layers in the state dict, as long as the logic that lead to the creation of the state_dict is correct, this will also be correct - for state_dict_modules in state_dict_modules_list: - for full_name, param in state_dict_modules.items(): - *module_path, attr_name = full_name.split(".") - module_path = ".".join(module_path) - module = model_to_load.get_submodule(module_path) - setattr(module, attr_name, param) - - del state_dict_modules_list + for state_dict_tensor_refs in state_dict_tensor_refs_list: + for full_tensor_name, tensor in state_dict_tensor_refs.items(): + module, tensor_name = get_module_from_name(model_to_load, full_tensor_name) + setattr(module, tensor_name, tensor) + + del state_dict_tensor_refs_list gc.collect() finally: # Restore the start method to prevent side effects for other code that may be running @@ -5107,12 +5097,12 @@ def _load_pretrained_model( args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") for args in args_list: - _error_msgs, disk_offload_index, cpu_offload_index, state_dict_modules = ( + _error_msgs, disk_offload_index, cpu_offload_index, state_dict_tensor_refs = ( load_shard_file(args) ) error_msgs += _error_msgs - del state_dict_modules + del state_dict_tensor_refs gc.collect() # Adjust offloaded weights name and save if needed From 8d04325e40199dcc32505b1223919ce915e7f89f Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Tue, 29 Apr 2025 18:27:34 -0400 Subject: [PATCH 18/27] Switch to the use of threads for parallel model loading. --- src/transformers/modeling_utils.py | 134 +++++++++-------------------- 1 file changed, 43 insertions(+), 91 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index e9c4e116bf43..b73ed49810d6 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -21,7 +21,6 @@ import inspect import itertools import json -import multiprocessing import os import re import shutil @@ -29,6 +28,7 @@ import warnings from collections import defaultdict from collections.abc import MutableMapping +from concurrent.futures import ThreadPoolExecutor, as_completed from contextlib import contextmanager from dataclasses import dataclass from enum import Enum @@ -836,18 +836,6 @@ def _load_state_dict_into_meta_model( return disk_offload_index, cpu_offload_index -def resolve_state_dict_tensor_refs(model_to_load, state_dict, expected_keys): - state_dict_tensor_refs = {} - - for full_tensor_name in state_dict.keys(): - if full_tensor_name not in expected_keys: - continue - - module, tensor_name = get_module_from_name(model_to_load, full_tensor_name) - - state_dict_tensor_refs[full_tensor_name] = getattr(module, tensor_name) - - return state_dict_tensor_refs # This function is in global scope so it's picklable for multiprocessing def load_shard_file(args): @@ -871,12 +859,12 @@ def load_shard_file(args): is_offloaded_safetensors, keep_in_fp32_regex, unexpected_keys, - device_mesh - ) = args + device_mesh, + ) = args # Skip the load for shards that only contain disk-offloaded weights if shard_file in disk_only_shard_files: - return [], disk_offload_index, cpu_offload_index, {} + return [], disk_offload_index, cpu_offload_index map_location = "cpu" if ( @@ -928,14 +916,13 @@ def load_shard_file(args): unexpected_keys=unexpected_keys, device_mesh=device_mesh, ) 
- # We now figure out what in the state dict changed and store the module used for each layer, this will contain the device - # information we need in order to resolve all of the layers after multiprocessing which we write back to the original model_to_load meta model - state_dict_tensor_refs = resolve_state_dict_tensor_refs(model_to_load, state_dict, expected_keys) - # force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop + # force memory release to avoid having multiple state dicts in memory as shards are processed del state_dict + gc.collect() + + return error_msgs, disk_offload_index, cpu_offload_index - return error_msgs, disk_offload_index, cpu_offload_index, state_dict_tensor_refs def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: if variant is not None: @@ -2651,9 +2638,9 @@ def tie_encoder_to_decoder_recursively( total_decoder_name="", total_encoder_name="", ): - assert isinstance(decoder_pointer, nn.Module) and isinstance(encoder_pointer, nn.Module), ( - f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module" - ) + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module" if hasattr(decoder_pointer, "weight"): assert hasattr(encoder_pointer, "weight") encoder_pointer.weight = decoder_pointer.weight @@ -2667,9 +2654,9 @@ def tie_encoder_to_decoder_recursively( encoder_modules = encoder_pointer._modules decoder_modules = decoder_pointer._modules if len(decoder_modules) > 0: - assert len(encoder_modules) > 0, ( - f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" - ) + assert ( + len(encoder_modules) > 0 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()} encoder_layer_pos = 0 @@ -4995,9 +4982,6 @@ def _load_pretrained_model( cpu_offload_folder = tempfile.mkdtemp() cpu_offload_index = {} - # For nice tqdm bars - if checkpoint_files is not None and len(checkpoint_files) > 1: - checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards") # To be able to iterate, even if we don't use it if the state_dict is already provided elif state_dict is not None: checkpoint_files = [""] @@ -5034,7 +5018,7 @@ def _load_pretrained_model( is_offloaded_safetensors, keep_in_fp32_regex, unexpected_keys, - device_mesh + device_mesh, ) for shard_file in checkpoint_files ] @@ -5043,68 +5027,36 @@ def _load_pretrained_model( # Use multiprocessing Pool for parallel execution, off by default if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")) and not is_deepspeed_zero3_enabled(): - original_start_method = multiprocessing.get_start_method(allow_none=True) + num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) - try: - # CUDA requires the start method to be spawn, fork creates multiple copies of the cuda runtime, which throws - multiprocessing.set_start_method("spawn", force=True) - - num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) - - # Do not spawn anymore workers than you need - num_workers = min(len(args_list), num_workers) - - logger.info(f"Loading model weights in parallel with {num_workers} workers...") - state_dict_tensor_refs_list = [] - - with multiprocessing.Pool(processes=num_workers) as pool: - # For nice tqdm bars - with 
logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: - # NOTE order does not matter, layers that changed per shard are unique and can be reassigned to the orignal meta model - for result in pool.imap_unordered(load_shard_file, args_list): - ( - _error_msgs, - disk_offload_index, - cpu_offload_index, - state_dict_tensor_refs, - ) = result - - error_msgs += _error_msgs - - state_dict_tensor_refs_list.append(state_dict_tensor_refs) - - pbar.update(1) - - # We now update each layer of the meta model with the tensor module refs that were set to specific devices in the copy of the meta model for each worker - # We are transferring that state into the orginal ref (model_to_load) here - # This is required because model_to_load is pickled when using multiprocessing, which means the ref to model_to_load is different for each worker, so you only get some of the state with respect to the loaded tensors - # You could in theory return each worker's copy of the model and use .named_parameters(), and .named_buffers(), but this appears to be more robust - # in that all you have to care about are the names of the layers in the state dict, as long as the logic that lead to the creation of the state_dict is correct, this will also be correct - for state_dict_tensor_refs in state_dict_tensor_refs_list: - for full_tensor_name, tensor in state_dict_tensor_refs.items(): - module, tensor_name = get_module_from_name(model_to_load, full_tensor_name) - setattr(module, tensor_name, tensor) - - del state_dict_tensor_refs_list - gc.collect() - finally: - # Restore the start method to prevent side effects for other code that may be running - multiprocessing.set_start_method(original_start_method, force=True) + # Do not spawn anymore workers than you need + num_workers = min(len(args_list), num_workers) + + logger.info(f"Loading model weights in parallel with {num_workers} workers...") + + # with multiprocessing.Pool(processes=num_workers) as pool: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: + futures = [executor.submit(load_shard_file, arg) for arg in args_list] + for future in as_completed(futures): + result = future.result() + ( + _error_msgs, + disk_offload_index, + cpu_offload_index, + ) = result + error_msgs += _error_msgs + + pbar.update(1) else: if len(args_list) > 1: - # For nice tqdm bars args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") for args in args_list: - _error_msgs, disk_offload_index, cpu_offload_index, state_dict_tensor_refs = ( - load_shard_file(args) - ) + _error_msgs, disk_offload_index, cpu_offload_index = load_shard_file(args) error_msgs += _error_msgs - del state_dict_tensor_refs - gc.collect() - # Adjust offloaded weights name and save if needed if disk_offload_index is not None and len(disk_offload_index) > 0: if loading_task_model_from_base_state_dict: @@ -5648,9 +5600,9 @@ def forward( Returns: `torch.FloatTensor`: The end logits for SQuAD. 
""" - assert start_states is not None or start_positions is not None, ( - "One of start_states, start_positions should be not None" - ) + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" if start_positions is not None: slen, hsz = hidden_states.shape[-2:] start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) @@ -5720,9 +5672,9 @@ def forward( """ # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. hsz = hidden_states.shape[-1] - assert start_states is not None or start_positions is not None, ( - "One of start_states, start_positions should be not None" - ) + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" if start_positions is not None: start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) @@ -6177,4 +6129,4 @@ def valid_keys(self) -> List[str]: # Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones -ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() \ No newline at end of file +ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface() From 674ec3730f9ef2e9034978fc0338bba86a8e2d45 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Tue, 29 Apr 2025 18:39:18 -0400 Subject: [PATCH 19/27] Update docs for parallel loading. --- docs/source/en/reference/environment_variables.md | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/source/en/reference/environment_variables.md b/docs/source/en/reference/environment_variables.md index fc09a98b5aec..6716e4e2376d 100644 --- a/docs/source/en/reference/environment_variables.md +++ b/docs/source/en/reference/environment_variables.md @@ -31,7 +31,6 @@ import os os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" -import multiprocessing from transformers import pipeline model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") @@ -39,9 +38,9 @@ model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="a ## HF_PARALLEL_LOADING_WORKERS -Determines how many child processes should be used when parallel loading is enabled. Default is `8`. +Determines how many threads should be used when parallel loading is enabled. Default is `8`. -If the number of files that are being loaded is less than the number of child processes specified, the number that is actually spawned will be equal to the number of files. +If the number of files that are being loaded is less than the number of threads specified, the number that is actually spawned will be equal to the number of files. e.g. If you specify 8 workers, and there are only 2 files, only 2 workers will be spawned. @@ -53,7 +52,6 @@ import os os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true" os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4" -import multiprocessing from transformers import pipeline model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto") From b8a1470f6efb03f4cb4f2cc67cce54eb9c7bfeb1 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Wed, 30 Apr 2025 17:59:43 -0400 Subject: [PATCH 20/27] Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADING. Prefer simple casting. 
--- src/transformers/modeling_utils.py | 7 +++++-- src/transformers/utils/import_utils.py | 5 +++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b73ed49810d6..a899237f094b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -125,6 +125,7 @@ is_sagemaker_mp_enabled, is_torch_fx_proxy, is_torchdynamo_compiling, + is_true, ) from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod @@ -5026,8 +5027,10 @@ def _load_pretrained_model( error_msgs = [] # Use multiprocessing Pool for parallel execution, off by default - if json.loads(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")) and not is_deepspeed_zero3_enabled(): - num_workers = json.loads(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) + ENV_VARS_TRUE_VALUES + + if is_true(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")) and not is_deepspeed_zero3_enabled(): + num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) # Do not spawn anymore workers than you need num_workers = min(len(args_list), num_workers) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 9654d5d1fff3..1fb0a94beaf4 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -86,6 +86,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) +def is_true(value: Optional[str]) -> bool: + if value is None: + return False + return value.upper() in ENV_VARS_TRUE_VALUES + USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() From efb6605b6bd0db95bf80d6aba98d81808ee6ba93 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Wed, 30 Apr 2025 18:14:47 -0400 Subject: [PATCH 21/27] Move parallelized shard loading into its own function. 
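For reference, a condensed sketch of the shape of the extracted helper
(simplified: the disk/CPU offload index bookkeeping and the tqdm progress bar
are omitted, load_one_shard stands in for the real load_shard_file, and a
non-empty args_list is assumed):

    import os
    from concurrent.futures import ThreadPoolExecutor, as_completed

    def load_shard_files_with_threadpool(args_list, load_one_shard):
        # Never start more threads than there are shard files to load.
        num_workers = min(len(args_list), int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")))
        error_msgs = []
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(load_one_shard, args) for args in args_list]
            # Completion order does not matter: each shard assigns a disjoint set of
            # weights, so results only need to be folded into the shared error list.
            for future in as_completed(futures):
                error_msgs += future.result()
        return error_msgs

Factoring the pool out keeps _load_pretrained_model down to a single branch
between the serial loop and this parallel path.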
--- src/transformers/modeling_utils.py | 56 ++++++++++++++++-------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index a899237f094b..3b40e6580b7d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -838,7 +838,6 @@ def _load_state_dict_into_meta_model( return disk_offload_index, cpu_offload_index -# This function is in global scope so it's picklable for multiprocessing def load_shard_file(args): ( shard_file, @@ -924,6 +923,32 @@ def load_shard_file(args): return error_msgs, disk_offload_index, cpu_offload_index +def load_shard_files_with_threadpool(args_list): + num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) + + # Do not spawn anymore workers than you need + num_workers = min(len(args_list), num_workers) + + logger.info(f"Loading model weights in parallel with {num_workers} workers...") + + error_msgs = [] + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: + futures = [executor.submit(load_shard_file, arg) for arg in args_list] + for future in as_completed(futures): + result = future.result() + ( + _error_msgs, + disk_offload_index, + cpu_offload_index, + ) = result + + error_msgs += _error_msgs + + pbar.update(1) + + return error_msgs, disk_offload_index, cpu_offload_index def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: if variant is not None: @@ -4997,7 +5022,7 @@ def _load_pretrained_model( expanded_device_map = expand_device_map(device_map, expected_keys) caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer) - # Prepare arguments for multiprocessing + # Prepare and compatabilize arguments for serial and parallel shard loading args_list = [ ( shard_file, @@ -5026,32 +5051,9 @@ def _load_pretrained_model( error_msgs = [] - # Use multiprocessing Pool for parallel execution, off by default - ENV_VARS_TRUE_VALUES - if is_true(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")) and not is_deepspeed_zero3_enabled(): - num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8")) - - # Do not spawn anymore workers than you need - num_workers = min(len(args_list), num_workers) - - logger.info(f"Loading model weights in parallel with {num_workers} workers...") - - # with multiprocessing.Pool(processes=num_workers) as pool: - with ThreadPoolExecutor(max_workers=num_workers) as executor: - with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar: - futures = [executor.submit(load_shard_file, arg) for arg in args_list] - for future in as_completed(futures): - result = future.result() - ( - _error_msgs, - disk_offload_index, - cpu_offload_index, - ) = result - - error_msgs += _error_msgs - - pbar.update(1) + _error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list) + error_msgs += _error_msgs else: if len(args_list) > 1: args_list = logging.tqdm(args_list, desc="Loading checkpoint shards") From c66daef802aea78abd1643df228a7644fb0d0805 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Thu, 1 May 2025 13:07:33 -0400 Subject: [PATCH 22/27] Remove use of is_true(). Favor checking env var true values for HF_ENABLE_PARALLEL_LOADING. 
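Roughly, the gate becomes a direct membership test against the
ENV_VARS_TRUE_VALUES set that utils/import_utils.py already defines, instead
of carrying a one-off is_true() helper (sketch only; the real condition in
the diff below also requires DeepSpeed ZeRO-3 to be disabled):

    import os

    ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # mirrors utils/import_utils.py

    use_parallel_loading = (
        os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
    )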
--- src/transformers/modeling_utils.py | 3 +-- src/transformers/utils/import_utils.py | 5 ----- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3b40e6580b7d..432d5b661648 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -125,7 +125,6 @@ is_sagemaker_mp_enabled, is_torch_fx_proxy, is_torchdynamo_compiling, - is_true, ) from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod @@ -5051,7 +5050,7 @@ def _load_pretrained_model( error_msgs = [] - if is_true(os.environ.get("HF_ENABLE_PARALLEL_LOADING", "false")) and not is_deepspeed_zero3_enabled(): + if os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES and not is_deepspeed_zero3_enabled(): _error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list) error_msgs += _error_msgs else: diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 1fb0a94beaf4..9654d5d1fff3 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -86,11 +86,6 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) -def is_true(value: Optional[str]) -> bool: - if value is None: - return False - return value.upper() in ENV_VARS_TRUE_VALUES - USE_TF = os.environ.get("USE_TF", "AUTO").upper() USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() From 4566c5ccf98e04224cd6215c0dd5cf0b7a2f9e33 Mon Sep 17 00:00:00 2001 From: Aaron Vogler Date: Wed, 14 May 2025 20:17:48 -0400 Subject: [PATCH 23/27] Update copyright to 2025 in readme for paralell model loading. --- docs/source/en/reference/environment_variables.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/en/reference/environment_variables.md b/docs/source/en/reference/environment_variables.md index 6716e4e2376d..fc20c08f9e6a 100644 --- a/docs/source/en/reference/environment_variables.md +++ b/docs/source/en/reference/environment_variables.md @@ -1,4 +1,4 @@ -