src/diffusers/models/model_loading_utils.py (6 additions, 1 deletion)
@@ -47,6 +47,7 @@
     is_torch_version,
     logging,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero
 
 
 logger = logging.get_logger(__name__)
@@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool(
         low_cpu_mem_usage=low_cpu_mem_usage,
     )
 
+    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
+    if not is_torch_dist_rank_zero():
+        tqdm_kwargs["disable"] = True
+
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
-        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+        with logging.tqdm(**tqdm_kwargs) as pbar:
             futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
             for future in as_completed(futures):
                 result = future.result()
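Stripped of the loader machinery, the gating idiom above reduces to the following minimal sketch with stock tqdm; the rank parameter is a stand-in for the process's distributed rank, which the actual patch obtains via is_torch_dist_rank_zero:

from tqdm.auto import tqdm

def iterate_with_progress(shard_files, rank):
    # Build the kwargs up front, then silence the bar off rank zero.
    # With disable=True, tqdm renders nothing but iteration proceeds normally.
    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
    if rank != 0:
        tqdm_kwargs["disable"] = True
    with tqdm(**tqdm_kwargs) as pbar:
        for shard_file in shard_files:
            pbar.update(1)  # one tick per shard, shown on rank zero only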
src/diffusers/models/modeling_utils.py (6 additions, 6 deletions)
@@ -59,11 +59,8 @@
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import (
-    PushToHubMixin,
-    load_or_create_model_card,
-    populate_model_card,
-)
+from ..utils.distributed_utils import is_torch_dist_rank_zero
+from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import empty_device_cache
 from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
 from .model_loading_utils import (
@@ -1672,7 +1669,10 @@ def _load_pretrained_model(
         else:
             shard_files = resolved_model_file
             if len(resolved_model_file) > 1:
-                shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+                shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
+                if not is_torch_dist_rank_zero():
+                    shard_tqdm_kwargs["disable"] = True
+                shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)
 
         for shard_file in shard_files:
             offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
src/diffusers/pipelines/pipeline_utils.py (12 additions, 3 deletions)
@@ -67,6 +67,7 @@
     logging,
     numpy_to_pil,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
 
@@ -982,7 +983,11 @@ def load_module(name, value):
         # 7. Load each module in the pipeline
         current_device_map = None
         _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
-        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
+        logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
+        if not is_torch_dist_rank_zero():
+            logging_tqdm_kwargs["disable"] = True
+
+        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
             # 7.1 device_map shenanigans
             if final_device_map is not None:
                 if isinstance(final_device_map, dict) and len(final_device_map) > 0:
@@ -1908,10 +1913,14 @@ def progress_bar(self, iterable=None, total=None):
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)

progress_bar_config = dict(self._progress_bar_config)
if "disable" not in progress_bar_config:
progress_bar_config["disable"] = not is_torch_dist_rank_zero()

if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
return tqdm(iterable, **progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
return tqdm(total=total, **progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")

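Note that the rank-based default is applied only when the caller has not set disable explicitly, so users can still force progress bars on every rank. A usage sketch (the checkpoint id is illustrative):

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# An explicit user setting wins over the rank-zero default above,
# re-enabling the per-step denoising bar on all ranks.
pipe.set_progress_bar_config(disable=False)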
src/diffusers/utils/distributed_utils.py (new file, 36 additions)
@@ -0,0 +1,36 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+try:
+    import torch
+except ImportError:
+    torch = None

Member Author (on lines +16 to +19): To avoid introducing a circular import problem.


+def is_torch_dist_rank_zero() -> bool:
+    if torch is None:
+        return True
+
+    dist_module = getattr(torch, "distributed", None)
+    if dist_module is None or not dist_module.is_available():
+        return True
+
+    if not dist_module.is_initialized():
+        return True
+
+    try:
+        return dist_module.get_rank() == 0
+    except (RuntimeError, ValueError):
+        return True

Collaborator (on def is_torch_dist_rank_zero): Is this robust to different distributed setups such as accelerate, pure deepspeed, etc.? As a concrete example, I'm thinking about the case where diffusers models are used in distributed training, and whether these changes would work as expected in that case.

Member Author: I don't think we can control external dependencies much, i.e., the logs from transformers, for example.
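For reference, a minimal sketch of how the helper behaves under torch.distributed, assuming a launch such as torchrun --nproc_per_node=2 demo.py (the gloo backend keeps it CPU-only):

import torch.distributed as dist

from diffusers.utils.distributed_utils import is_torch_dist_rank_zero

# Before a process group exists, the helper returns True on every process,
# so single-process runs keep their progress bars and logs.
assert is_torch_dist_rank_zero()

dist.init_process_group(backend="gloo")  # torchrun supplies RANK/WORLD_SIZE
if is_torch_dist_rank_zero():
    print("rank 0: progress bars and warning-level logs stay on")
else:
    print(f"rank {dist.get_rank()}: bars and warnings are suppressed")
dist.destroy_process_group()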
src/diffusers/utils/logging.py (23 additions, 1 deletion)
@@ -32,6 +32,8 @@

 from tqdm import auto as tqdm_lib
 
+from .distributed_utils import is_torch_dist_rank_zero
+
 
 _lock = threading.Lock()
 _default_handler: Optional[logging.Handler] = None
@@ -47,6 +49,23 @@
 _default_log_level = logging.WARNING
 
 _tqdm_active = True
+_rank_zero_filter = None
+
+
+class _RankZeroFilter(logging.Filter):
+    def filter(self, record):
+        # Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting.
+        return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG
+
+
+def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
+    global _rank_zero_filter
+
+    if _rank_zero_filter is None:
+        _rank_zero_filter = _RankZeroFilter()
+
+    if not any(isinstance(f, _RankZeroFilter) for f in logger.filters):
+        logger.addFilter(_rank_zero_filter)
 
 
 def _get_default_logging_level() -> int:
@@ -90,6 +109,7 @@ def _configure_library_root_logger() -> None:
         library_root_logger.addHandler(_default_handler)
         library_root_logger.setLevel(_get_default_logging_level())
         library_root_logger.propagate = False
+        _ensure_rank_zero_filter(library_root_logger)
 
 
 def _reset_library_root_logger() -> None:
@@ -120,7 +140,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
         name = _get_library_name()
 
     _configure_library_root_logger()
-    return logging.getLogger(name)
+    logger = logging.getLogger(name)
+    _ensure_rank_zero_filter(logger)
+    return logger
 
 
 def get_verbosity() -> int:
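Combined with the helper above, the filter yields rank-zero-only logging for every logger returned by get_logger. A minimal sketch of the intended behavior, assuming two processes launched with torchrun (gloo backend):

import torch.distributed as dist

from diffusers.utils import logging

dist.init_process_group(backend="gloo")
logger = logging.get_logger("diffusers")

# _RankZeroFilter passes a record when the process is rank zero OR the
# record is DEBUG or lower, so this warning prints once, not once per rank.
logger.warning("loading checkpoint shards")

# DEBUG records are deliberately let through on every rank for troubleshooting;
# they only become visible after lowering the verbosity.
logging.set_verbosity_debug()
logger.debug(f"rank {dist.get_rank()} detail")

dist.destroy_process_group()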