Merged

Changes from 26 commits

Commits (34)
8fb9b18
Get parallel loader working. Include tests.
inf3rnus Mar 19, 2025
27f36f2
Update the tests for parallel loading
inf3rnus Mar 19, 2025
7e5ecd8
Merge branch 'main' into 03-18-25-parallel-model-loading
inf3rnus Mar 19, 2025
e7c3ea5
Rename env variables.
inf3rnus Mar 19, 2025
7599fe2
Add docs for parallel model weight loading.
inf3rnus Mar 19, 2025
065e102
Touch up parallel model loading docs.
inf3rnus Mar 19, 2025
d31594a
Touch up parallel model loading docs again.
inf3rnus Mar 19, 2025
33b3e0f
Edit comment in test_modeling_utils_parallel_loading.py
inf3rnus Mar 19, 2025
3fb6b65
Merge branch 'main' into 03-18-25-parallel-model-loading
inf3rnus Mar 19, 2025
0e22c04
Make sure HF_PARALLEL_LOADING_WORKERS is spelled correctly in modelin…
inf3rnus Mar 19, 2025
904bdaf
Correct times for parallelized loading, previous times were for a "ho…
inf3rnus Mar 21, 2025
7e37ba4
Update parallel model loading so the spawn method is encapsulated. DR…
inf3rnus Mar 24, 2025
a203f6a
Update docs on model loading parallelism so that details on setting t…
inf3rnus Mar 24, 2025
14e9eef
Fix style on model loading parallelism changes.
inf3rnus Mar 24, 2025
fe1fc0c
Merge remote-tracking branch 'upstream/main' into 03-18-25-parallel-m…
inf3rnus Apr 8, 2025
d5637e8
Merge latest version of master's modeling_utils.
inf3rnus Apr 8, 2025
e0d37bb
Removed unused variable.
inf3rnus Apr 8, 2025
9b4165c
Fix argument packing for the parallel loader.
inf3rnus Apr 8, 2025
1085461
Fix state dict being undefined in the parallel model loader.
inf3rnus Apr 8, 2025
82ab2ec
Merge main.
inf3rnus Apr 29, 2025
7ae3db6
Rename variables used in parallel model loading for clarity. Use get_…
inf3rnus Apr 29, 2025
8d04325
Switch to the use of threads for parallel model loading.
inf3rnus Apr 29, 2025
674ec37
Update docs for parallel loading.
inf3rnus Apr 29, 2025
b8a1470
Remove the use of json.loads when evaluating HF_ENABLE_PARALLEL_LOADI…
inf3rnus Apr 30, 2025
efb6605
Move parallelized shard loading into its own function.
inf3rnus Apr 30, 2025
c66daef
Remove use of is_true(). Favor checking env var true values for HF_EN…
inf3rnus May 1, 2025
4566c5c
Update copyright to 2025 in readme for paralell model loading.
inf3rnus May 15, 2025
610c5e3
Remove garbage collection line in load_shard_file, implicit garbage c…
inf3rnus May 15, 2025
a9cb54b
Run formatter on modeling_utils.py
inf3rnus May 15, 2025
fc76fbb
Merge branch 'main' into 03-18-25-parallel-model-loading
inf3rnus May 15, 2025
16f3751
Apply style fixes
github-actions[bot] May 22, 2025
cd0f42e
Merge main.
inf3rnus May 22, 2025
3b9f458
Delete tests/utils/test_modeling_utils_parallel_loading.py
inf3rnus May 22, 2025
b6bf421
Merge branch 'main' into 03-18-25-parallel-model-loading
Cyrilvallez May 23, 2025
5 changes: 5 additions & 0 deletions docs/source/en/_toctree.yml
@@ -1107,4 +1107,9 @@
- local: internal/time_series_utils
title: Utilities for Time Series
title: Internal helpers
- sections:
- local: reference/environment_variables
title: Environment Variables
title: Reference
title: API

58 changes: 58 additions & 0 deletions docs/source/en/reference/environment_variables.md
@@ -0,0 +1,58 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Environment Variables

## HF_ENABLE_PARALLEL_LOADING

Disabled by default. When enabled, torch and safetensors weights are loaded in parallel, which can significantly reduce the time it takes to load large models, often yielding speedups of around 50%.

Set it to the string `"false"` or `"true"`, e.g. `os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"`.

For example, `facebook/opt-30b` on an AWS EC2 g4dn.metal instance loads in ~30s with this enabled versus ~55s without it.

Profile before committing to this environment variable; it will not produce speedups for smaller models.

```py
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"

from transformers import pipeline

model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
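
Since the benefit depends on model size, shard count, and storage throughput, a quick wall-clock comparison is the simplest way to decide whether to keep it on. A minimal sketch (the model choice and the use of `AutoModelForCausalLM` here are only illustrative; timings will vary per machine):

```py
import os
import time

# Toggle between "true" and "false" across two runs to compare load times.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"

from transformers import AutoModelForCausalLM

start = time.perf_counter()
model = AutoModelForCausalLM.from_pretrained("facebook/opt-30b", device_map="auto")
print(f"Loaded in {time.perf_counter() - start:.1f}s")
```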

## HF_PARALLEL_LOADING_WORKERS

Determines how many threads are used when parallel loading is enabled. Defaults to `8`.

If fewer files are being loaded than the number of threads specified, only as many threads as there are files are spawned.

For example, if you specify 8 workers but only 2 files need to be loaded, only 2 workers are spawned.

Tune as you see fit.

```py
import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "true"
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "4"

from transformers import pipeline

model = pipeline(task="text-generation", model="facebook/opt-30b", device_map="auto")
```
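
For reference, the loader never spawns more workers than there are shard files, so setting a generous value is harmless for small checkpoints. A minimal sketch of that capping logic (mirroring `load_shard_files_with_threadpool` in `modeling_utils.py`; the file list below is made up):

```py
import os

# Hypothetical shard list for a two-file checkpoint
shard_files = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]

num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
# Never spawn more workers than there are files to load
num_workers = min(len(shard_files), num_workers)

print(num_workers)  # 2, even though the default allows up to 8
```
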
231 changes: 161 additions & 70 deletions src/transformers/modeling_utils.py
@@ -28,6 +28,7 @@
import warnings
from collections import defaultdict
from collections.abc import MutableMapping
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
@@ -836,6 +837,118 @@ def _load_state_dict_into_meta_model(
return disk_offload_index, cpu_offload_index


def load_shard_file(args):
(
shard_file,
state_dict,
disk_only_shard_files,
is_hqq_or_bnb,
is_quantized,
device_map,
hf_quantizer,
key_renaming_mapping,
weights_only,
model_to_load,
expected_keys,
reverse_key_renaming_mapping,
disk_offload_folder,
disk_offload_index,
cpu_offload_folder,
cpu_offload_index,
is_offloaded_safetensors,
keep_in_fp32_regex,
unexpected_keys,
device_mesh,
) = args

# Skip the load for shards that only contain disk-offloaded weights
if shard_file in disk_only_shard_files:
return [], disk_offload_index, cpu_offload_index

map_location = "cpu"
if (
shard_file.endswith(".safetensors")
and not is_hqq_or_bnb
and not (is_deepspeed_zero3_enabled() and not is_quantized)
):
map_location = "meta"
elif (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and (
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
)
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])

# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)

# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}

error_msgs = []

if is_deepspeed_zero3_enabled() and not is_quantized:
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
# Skip it with fsdp on ranks other than 0
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
shard_file,
expected_keys,
reverse_key_renaming_mapping,
device_map=device_map,
disk_offload_folder=disk_offload_folder,
disk_offload_index=disk_offload_index,
cpu_offload_folder=cpu_offload_folder,
cpu_offload_index=cpu_offload_index,
hf_quantizer=hf_quantizer,
is_safetensors=is_offloaded_safetensors,
keep_in_fp32_regex=keep_in_fp32_regex,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)

# force memory release to avoid having multiple state dicts in memory as shards are processed
del state_dict
gc.collect()

return error_msgs, disk_offload_index, cpu_offload_index

def load_shard_files_with_threadpool(args_list):
num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))

# Do not spawn any more workers than needed
num_workers = min(len(args_list), num_workers)

logger.info(f"Loading model weights in parallel with {num_workers} workers...")

error_msgs = []

with ThreadPoolExecutor(max_workers=num_workers) as executor:
with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
futures = [executor.submit(load_shard_file, arg) for arg in args_list]
for future in as_completed(futures):
result = future.result()
(
_error_msgs,
disk_offload_index,
cpu_offload_index,
) = result

error_msgs += _error_msgs

pbar.update(1)

return error_msgs, disk_offload_index, cpu_offload_index

def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
path, name = weights_name.rsplit(".", 1)
@@ -2550,9 +2663,9 @@ def tie_encoder_to_decoder_recursively(
total_decoder_name="",
total_encoder_name="",
):
assert isinstance(decoder_pointer, nn.Module) and isinstance(encoder_pointer, nn.Module), (
f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
)
assert isinstance(decoder_pointer, nn.Module) and isinstance(
encoder_pointer, nn.Module
), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
if hasattr(decoder_pointer, "weight"):
assert hasattr(encoder_pointer, "weight")
encoder_pointer.weight = decoder_pointer.weight
@@ -2566,9 +2679,9 @@ def tie_encoder_to_decoder_recursively(
encoder_modules = encoder_pointer._modules
decoder_modules = decoder_pointer._modules
if len(decoder_modules) > 0:
assert len(encoder_modules) > 0, (
f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
)
assert (
len(encoder_modules) > 0
), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"

all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()}
encoder_layer_pos = 0
@@ -4894,9 +5007,6 @@ def _load_pretrained_model(
cpu_offload_folder = tempfile.mkdtemp()
cpu_offload_index = {}

# For nice tqdm bars
if checkpoint_files is not None and len(checkpoint_files) > 1:
checkpoint_files = logging.tqdm(checkpoint_files, desc="Loading checkpoint shards")
# To be able to iterate, even if we don't use it if the state_dict is already provided
elif state_dict is not None:
checkpoint_files = [""]
@@ -4911,64 +5021,45 @@ def _load_pretrained_model(
expanded_device_map = expand_device_map(device_map, expected_keys)
caching_allocator_warmup(model_to_load, expanded_device_map, hf_quantizer)

error_msgs = []
# Iterate on all the shards to load the weights
for shard_file in checkpoint_files:
# Skip the load for shards that only contain disk-offloaded weights
if shard_file in disk_only_shard_files:
continue

map_location = "cpu"
if (
shard_file.endswith(".safetensors")
and not is_hqq_or_bnb
and not (is_deepspeed_zero3_enabled() and not is_quantized)
):
map_location = "meta"
elif (
device_map is not None
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
and (
hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
or isinstance(hf_quantizer.quantization_config.quant_type, Int4WeightOnlyConfig)
)
):
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])

# If shard_file is "", we use the existing state_dict instead of loading it
if shard_file != "":
state_dict = load_state_dict(
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
)
# Prepare the arguments in a form compatible with both serial and parallel shard loading
args_list = [
(
shard_file,
state_dict,
disk_only_shard_files,
is_hqq_or_bnb,
is_quantized,
device_map,
hf_quantizer,
key_renaming_mapping,
weights_only,
model_to_load,
expected_keys,
reverse_key_renaming_mapping,
disk_offload_folder,
disk_offload_index,
cpu_offload_folder,
cpu_offload_index,
is_offloaded_safetensors,
keep_in_fp32_regex,
unexpected_keys,
device_mesh,
)
for shard_file in checkpoint_files
]

# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
error_msgs = []

if is_deepspeed_zero3_enabled() and not is_quantized:
error_msgs += _load_state_dict_into_zero3_model(model_to_load, state_dict)
# Skip it with fsdp on ranks other than 0
elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
disk_offload_index, cpu_offload_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
shard_file,
expected_keys,
reverse_key_renaming_mapping,
device_map=device_map,
disk_offload_folder=disk_offload_folder,
disk_offload_index=disk_offload_index,
cpu_offload_folder=cpu_offload_folder,
cpu_offload_index=cpu_offload_index,
hf_quantizer=hf_quantizer,
is_safetensors=is_offloaded_safetensors,
keep_in_fp32_regex=keep_in_fp32_regex,
unexpected_keys=unexpected_keys,
device_mesh=device_mesh,
)
if os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES and not is_deepspeed_zero3_enabled():
_error_msgs, disk_offload_index, cpu_offload_index = load_shard_files_with_threadpool(args_list)
error_msgs += _error_msgs
else:
if len(args_list) > 1:
args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")

# force memory release if loading multiple shards, to avoid having 2 state dicts in memory in next loop
del state_dict
for args in args_list:
_error_msgs, disk_offload_index, cpu_offload_index = load_shard_file(args)
error_msgs += _error_msgs

# Adjust offloaded weights name and save if needed
if disk_offload_index is not None and len(disk_offload_index) > 0:
@@ -5513,9 +5604,9 @@ def forward(
Returns:
`torch.FloatTensor`: The end logits for SQuAD.
"""
assert start_states is not None or start_positions is not None, (
"One of start_states, start_positions should be not None"
)
assert (
start_states is not None or start_positions is not None
), "One of start_states, start_positions should be not None"
if start_positions is not None:
slen, hsz = hidden_states.shape[-2:]
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
@@ -5585,9 +5676,9 @@ def forward(
"""
# No dependency on end_feature so that we can obtain one single `cls_logits` for each sample.
hsz = hidden_states.shape[-1]
assert start_states is not None or start_positions is not None, (
"One of start_states, start_positions should be not None"
)
assert (
start_states is not None or start_positions is not None
), "One of start_states, start_positions should be not None"
if start_positions is not None:
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz)