Skip to content

Commit ae2561b

Browse files
committed
review feedback.
1 parent 36c86d2 commit ae2561b

File tree

4 files changed

+5
-3
lines changed

4 files changed

+5
-3
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,8 +407,7 @@ def _load_shard_files_with_threadpool(
407407
low_cpu_mem_usage=False,
408408
):
409409
# Do not spawn anymore workers than you need
410-
num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", str(DEFAULT_HF_PARALLEL_LOADING_WORKERS)))
411-
num_workers = min(len(shard_files), num_workers)
410+
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
412411

413412
logger.info(f"Loading model weights in parallel with {num_workers} workers...")
414413

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
CONFIG_NAME,
4545
ENV_VARS_TRUE_VALUES,
4646
FLAX_WEIGHTS_NAME,
47+
HF_PARALLEL_LOADING_FLAG,
4748
SAFE_WEIGHTS_INDEX_NAME,
4849
SAFETENSORS_WEIGHTS_NAME,
4950
WEIGHTS_INDEX_NAME,
@@ -961,7 +962,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
961962
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
962963
disable_mmap = kwargs.pop("disable_mmap", False)
963964

964-
is_parallel_loading_enabled = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
965+
is_parallel_loading_enabled = os.environ.get(HF_PARALLEL_LOADING_FLAG, "").upper() in ENV_VARS_TRUE_VALUES
965966
if is_parallel_loading_enabled and not low_cpu_mem_usage:
966967
raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
967968

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
FLAX_WEIGHTS_NAME,
2727
GGUF_FILE_EXTENSION,
2828
HF_MODULES_CACHE,
29+
HF_PARALLEL_LOADING_FLAG,
2930
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
3031
MIN_PEFT_VERSION,
3132
ONNX_EXTERNAL_WEIGHTS_NAME,

src/diffusers/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
4545
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
4646
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
47+
HF_PARALLEL_LOADING_FLAG = "HF_ENABLE_PARALLEL_LOADING"
4748

4849
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
4950
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

0 commit comments

Comments
 (0)