Skip to content

Commit 123506e

Browse files
authored
make parallel loading flag a part of constants. (#12137)
1 parent 8c48ec0 commit 123506e

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@
4242
from ..quantizers.quantization_config import QuantizationMethod
4343
from ..utils import (
4444
CONFIG_NAME,
45-
ENV_VARS_TRUE_VALUES,
4645
FLAX_WEIGHTS_NAME,
47-
HF_PARALLEL_LOADING_FLAG,
46+
HF_ENABLE_PARALLEL_LOADING,
4847
SAFE_WEIGHTS_INDEX_NAME,
4948
SAFETENSORS_WEIGHTS_NAME,
5049
WEIGHTS_INDEX_NAME,
@@ -962,7 +961,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
962961
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
963962
disable_mmap = kwargs.pop("disable_mmap", False)
964963

965-
is_parallel_loading_enabled = os.environ.get(HF_PARALLEL_LOADING_FLAG, "").upper() in ENV_VARS_TRUE_VALUES
964+
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
966965
if is_parallel_loading_enabled and not low_cpu_mem_usage:
967966
raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
968967

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
DIFFUSERS_DYNAMIC_MODULE_NAME,
2626
FLAX_WEIGHTS_NAME,
2727
GGUF_FILE_EXTENSION,
28+
HF_ENABLE_PARALLEL_LOADING,
2829
HF_MODULES_CACHE,
29-
HF_PARALLEL_LOADING_FLAG,
3030
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
3131
MIN_PEFT_VERSION,
3232
ONNX_EXTERNAL_WEIGHTS_NAME,

src/diffusers/utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
4545
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
4646
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
47-
HF_PARALLEL_LOADING_FLAG = "HF_ENABLE_PARALLEL_LOADING"
47+
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
4848

4949
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
5050
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

0 commit comments

Comments
 (0)