diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 8b48ba6b4873..08b3f0234f82 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -47,6 +47,7 @@
     is_torch_version,
     logging,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero


 logger = logging.get_logger(__name__)
@@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool(
         low_cpu_mem_usage=low_cpu_mem_usage,
     )

+    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
+    if not is_torch_dist_rank_zero():
+        tqdm_kwargs["disable"] = True
+
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
-        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+        with logging.tqdm(**tqdm_kwargs) as pbar:
             futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
             for future in as_completed(futures):
                 result = future.result()
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 41da95d3a2a2..ba5b93605f50 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -59,11 +59,8 @@
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import (
-    PushToHubMixin,
-    load_or_create_model_card,
-    populate_model_card,
-)
+from ..utils.distributed_utils import is_torch_dist_rank_zero
+from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import empty_device_cache
 from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
 from .model_loading_utils import (
@@ -1672,7 +1669,10 @@ def _load_pretrained_model(
         else:
             shard_files = resolved_model_file
             if len(resolved_model_file) > 1:
-                shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+                shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
+                if not is_torch_dist_rank_zero():
+                    shard_tqdm_kwargs["disable"] = True
+                shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)

             for shard_file in shard_files:
                 offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 392d5fb3feb4..5c4ac8a6554a 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -67,6 +67,7 @@
     logging,
     numpy_to_pil,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module

@@ -982,7 +983,11 @@ def load_module(name, value):
         # 7. Load each module in the pipeline
         current_device_map = None
         _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
-        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
+        logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
+        if not is_torch_dist_rank_zero():
+            logging_tqdm_kwargs["disable"] = True
+
+        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
             # 7.1 device_map shenanigans
             if final_device_map is not None:
                 if isinstance(final_device_map, dict) and len(final_device_map) > 0:
@@ -1908,10 +1913,14 @@ def progress_bar(self, iterable=None, total=None):
                 f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
             )

+        progress_bar_config = dict(self._progress_bar_config)
+        if "disable" not in progress_bar_config:
+            progress_bar_config["disable"] = not is_torch_dist_rank_zero()
+
         if iterable is not None:
-            return tqdm(iterable, **self._progress_bar_config)
+            return tqdm(iterable, **progress_bar_config)
         elif total is not None:
-            return tqdm(total=total, **self._progress_bar_config)
+            return tqdm(total=total, **progress_bar_config)
         else:
             raise ValueError("Either `total` or `iterable` has to be defined.")

diff --git a/src/diffusers/utils/distributed_utils.py b/src/diffusers/utils/distributed_utils.py
new file mode 100644
index 000000000000..239b7b26200d
--- /dev/null
+++ b/src/diffusers/utils/distributed_utils.py
@@ -0,0 +1,36 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+def is_torch_dist_rank_zero() -> bool:
+    if torch is None:
+        return True
+
+    dist_module = getattr(torch, "distributed", None)
+    if dist_module is None or not dist_module.is_available():
+        return True
+
+    if not dist_module.is_initialized():
+        return True
+
+    try:
+        return dist_module.get_rank() == 0
+    except (RuntimeError, ValueError):
+        return True
diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py
index 2ad6d3a47607..80e108e4a6ff 100644
--- a/src/diffusers/utils/logging.py
+++ b/src/diffusers/utils/logging.py
@@ -32,6 +32,8 @@

 from tqdm import auto as tqdm_lib

+from .distributed_utils import is_torch_dist_rank_zero
+

 _lock = threading.Lock()
 _default_handler: Optional[logging.Handler] = None
@@ -47,6 +49,23 @@
 _default_log_level = logging.WARNING

 _tqdm_active = True
+_rank_zero_filter = None
+
+
+class _RankZeroFilter(logging.Filter):
+    def filter(self, record):
+        # Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting.
+        return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG
+
+
+def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
+    global _rank_zero_filter
+
+    if _rank_zero_filter is None:
+        _rank_zero_filter = _RankZeroFilter()
+
+    if not any(isinstance(f, _RankZeroFilter) for f in logger.filters):
+        logger.addFilter(_rank_zero_filter)


 def _get_default_logging_level() -> int:
@@ -90,6 +109,7 @@
     library_root_logger.addHandler(_default_handler)
     library_root_logger.setLevel(_get_default_logging_level())
     library_root_logger.propagate = False
+    _ensure_rank_zero_filter(library_root_logger)


 def _reset_library_root_logger() -> None:
@@ -120,7 +140,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
         name = _get_library_name()

     _configure_library_root_logger()
-    return logging.getLogger(name)
+    logger = logging.getLogger(name)
+    _ensure_rank_zero_filter(logger)
+    return logger


 def get_verbosity() -> int:
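
For context, a minimal usage sketch of how the gating is expected to behave when a pipeline is loaded under `torchrun`. This is not part of the patch; the checkpoint id and script name are placeholders. Only the rank-0 process should show the "Loading pipeline components..." and shard progress bars, while the other ranks load silently.

```python
# Launch with, e.g.: torchrun --nproc_per_node=2 load_pipeline.py
import torch.distributed as dist

from diffusers import DiffusionPipeline
from diffusers.utils.distributed_utils import is_torch_dist_rank_zero

# torchrun sets RANK/WORLD_SIZE/MASTER_ADDR, so the default env:// init works here.
dist.init_process_group(backend="gloo")  # "nccl" on GPU setups

# Progress bars and warning-level diffusers logs now appear on rank 0 only;
# non-zero ranks get tqdm disabled via the new `disable` kwarg.
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

if is_torch_dist_rank_zero():
    print("rank 0 finished loading")

dist.destroy_process_group()
```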
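A small verification sketch (assumed test layout, not included in the PR) of the `_RankZeroFilter` semantics: warning-level records are dropped on non-zero ranks, while debug-level records still pass so per-rank troubleshooting remains possible.

```python
import logging
from unittest import mock

from diffusers.utils import logging as diffusers_logging


def make_record(level):
    # Build a bare LogRecord at the given level for feeding the filter directly.
    return logging.LogRecord("diffusers", level, __file__, 0, "msg", None, None)


rank_filter = diffusers_logging._RankZeroFilter()

# Pretend we are a non-zero rank by patching the helper the filter calls.
with mock.patch.object(diffusers_logging, "is_torch_dist_rank_zero", return_value=False):
    assert rank_filter.filter(make_record(logging.DEBUG)) is True
    assert rank_filter.filter(make_record(logging.WARNING)) is False

# On rank zero (or outside torch.distributed), everything passes.
with mock.patch.object(diffusers_logging, "is_torch_dist_rank_zero", return_value=True):
    assert rank_filter.filter(make_record(logging.WARNING)) is True
```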
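Finally, because `progress_bar()` only fills in `disable` when that key is absent from `_progress_bar_config`, an explicit user choice still wins on every rank. A hypothetical snippet (placeholder checkpoint id and prompt):

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

# Explicitly re-enable denoising progress bars, e.g. while debugging a single
# non-zero rank; the user-provided "disable" key takes precedence over the rank check.
pipe.set_progress_bar_config(disable=False)
images = pipe("an astronaut riding a horse").images
```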