
Commit 736971c

Merge branch 'main' into cuda-device-map-pipe
2 parents 5e6f142 + 421ee07 commit 736971c

File tree

4 files changed, +28 -5 lines changed


docs/source/en/using-diffusers/loading.md

Lines changed: 24 additions & 0 deletions
@@ -112,6 +112,30 @@ print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)
 
 If a component is not explicitly specified in the dictionary and no `default` is provided, it will be loaded with `torch.float32`.
 
+### Parallel loading
+
+Large models are often [sharded](../training/distributed_inference#model-sharding) into smaller files so that they are easier to load. Diffusers supports loading shards in parallel to speed up the loading process.
+
+Set the environment variables below to enable parallel loading.
+
+- Set `HF_ENABLE_PARALLEL_LOADING` to `"YES"` to enable parallel loading of shards.
+- Set `HF_PARALLEL_LOADING_WORKERS` to configure the number of parallel threads to use when loading shards. More workers load a model faster but use more memory.
+
+The `device_map` argument should be set to `"cuda"` to pre-allocate a large chunk of memory based on the model size. This substantially reduces model load time because warming up the memory allocator up front avoids many smaller calls to the allocator later.
+
+```py
+import os
+import torch
+from diffusers import DiffusionPipeline
+
+os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
+pipeline = DiffusionPipeline.from_pretrained(
+    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
+    torch_dtype=torch.bfloat16,
+    device_map="cuda"
+)
+```
+
 ### Local pipeline
 
 To load a pipeline locally, use [git-lfs](https://git-lfs.github.com/) to manually download a checkpoint to your local disk.
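
The docs snippet above only sets `HF_ENABLE_PARALLEL_LOADING`. Below is a minimal sketch of tuning the worker count as well, assuming `HF_PARALLEL_LOADING_WORKERS` is read as the new docs describe, with the default of 8 coming from `DEFAULT_HF_PARALLEL_LOADING_WORKERS` in `constants.py`; the value 16 is only an example.

```py
import os

# Enable parallel shard loading and raise the worker count above the default of 8.
# The value 16 is illustrative; more workers load faster but use more memory.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"
os.environ["HF_PARALLEL_LOADING_WORKERS"] = "16"

import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```

Setting both variables before importing `diffusers` keeps the example independent of when the library reads them.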

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 3 deletions
@@ -42,9 +42,8 @@
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
-    ENV_VARS_TRUE_VALUES,
     FLAX_WEIGHTS_NAME,
-    HF_PARALLEL_LOADING_FLAG,
+    HF_ENABLE_PARALLEL_LOADING,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -962,7 +961,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
-        is_parallel_loading_enabled = os.environ.get(HF_PARALLEL_LOADING_FLAG, "").upper() in ENV_VARS_TRUE_VALUES
+        is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
         if is_parallel_loading_enabled and not low_cpu_mem_usage:
             raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
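
For context on the guard retained above: with the flag enabled, `from_pretrained` still requires `low_cpu_mem_usage`. A small sketch of what the check means for callers follows; the `AutoModel` class, the checkpoint, and the explicit `low_cpu_mem_usage=False` are illustrative, and `low_cpu_mem_usage` normally defaults to `True` when accelerate is installed.

```py
import os

# Set the flag before importing diffusers so the module-level constant picks it up.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"

import torch
from diffusers import AutoModel

try:
    # Explicitly disabling low_cpu_mem_usage trips the guard shown in the diff above.
    model = AutoModel.from_pretrained(
        "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=False,
    )
except NotImplementedError as err:
    print(err)  # Parallel loading is not supported when not using `low_cpu_mem_usage`.
```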

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -25,8 +25,8 @@
     DIFFUSERS_DYNAMIC_MODULE_NAME,
     FLAX_WEIGHTS_NAME,
     GGUF_FILE_EXTENSION,
+    HF_ENABLE_PARALLEL_LOADING,
     HF_MODULES_CACHE,
-    HF_PARALLEL_LOADING_FLAG,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     MIN_PEFT_VERSION,
     ONNX_EXTERNAL_WEIGHTS_NAME,

src/diffusers/utils/constants.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@
 DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
 DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
-HF_PARALLEL_LOADING_FLAG = "HF_ENABLE_PARALLEL_LOADING"
+HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
 
 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
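
One implication of the new constant, if it is resolved once when `diffusers.utils.constants` is imported as the line above suggests: the environment variable should be set before the first `diffusers` import in a process for the flag to be picked up. A short sketch of that ordering follows; the exact contents of `ENV_VARS_TRUE_VALUES` are not shown in this diff and are assumed to be the usual Hugging Face truthy strings.

```py
import os

# Set the flag before importing diffusers so the module-level constant sees it.
os.environ["HF_ENABLE_PARALLEL_LOADING"] = "YES"

from diffusers.utils import constants

# True when the variable was set to a truthy value before the import above.
print(constants.HF_ENABLE_PARALLEL_LOADING)
# Accepted truthy strings live in ENV_VARS_TRUE_VALUES (assumed: values like "1", "ON", "YES", "TRUE").
print(constants.ENV_VARS_TRUE_VALUES)
```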
