From f207108643a2c44192986284cd87c0f5337d2bbf Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 8 Dec 2025 18:13:38 +0530 Subject: [PATCH 1/7] disable progressbar in distributed. --- src/diffusers/pipelines/pipeline_utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 392d5fb3feb4..8a67e3d09b54 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -1908,16 +1908,28 @@ def progress_bar(self, iterable=None, total=None): f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." ) + progress_bar_config = dict(self._progress_bar_config) + if "disable" not in progress_bar_config: + progress_bar_config["disable"] = self._progress_bar_disabled_for_rank() + if iterable is not None: - return tqdm(iterable, **self._progress_bar_config) + return tqdm(iterable, **progress_bar_config) elif total is not None: - return tqdm(total=total, **self._progress_bar_config) + return tqdm(total=total, **progress_bar_config) else: raise ValueError("Either `total` or `iterable` has to be defined.") def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs + def _progress_bar_disabled_for_rank(self): + if torch.distributed.is_available() and torch.distributed.is_initialized(): + try: + return torch.distributed.get_rank() != 0 + except (RuntimeError, ValueError): + return False + return False + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): r""" Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). When this From 97e380573c5c2ade17b98d1b2b8b35ecb22b2a5a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 8 Dec 2025 18:37:52 +0530 Subject: [PATCH 2/7] up --- src/diffusers/pipelines/pipeline_utils.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8a67e3d09b54..6f37208227db 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -982,7 +982,11 @@ def load_module(name, value): # 7. 
Load each module in the pipeline current_device_map = None _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config) - for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."): + logging_tqdm_kwargs = {"desc": "Loading pipeline components..."} + if cls._progress_bar_disabled_for_rank(): + logging_tqdm_kwargs["disable"] = True + + for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs): # 7.1 device_map shenanigans if final_device_map is not None: if isinstance(final_device_map, dict) and len(final_device_map) > 0: @@ -1922,7 +1926,8 @@ def progress_bar(self, iterable=None, total=None): def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs - def _progress_bar_disabled_for_rank(self): + @staticmethod + def _progress_bar_disabled_for_rank(): if torch.distributed.is_available() and torch.distributed.is_initialized(): try: return torch.distributed.get_rank() != 0 From 8df3fbcc143610483a43aec48df30a26b3c025bc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 8 Dec 2025 18:47:26 +0530 Subject: [PATCH 3/7] up --- src/diffusers/models/modeling_utils.py | 7 +++++-- src/diffusers/pipelines/pipeline_utils.py | 15 +++------------ src/diffusers/utils/torch_utils.py | 17 +++++++++++++++++ 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 41da95d3a2a2..567fb424d4de 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -64,7 +64,7 @@ load_or_create_model_card, populate_model_card, ) -from ..utils.torch_utils import empty_device_cache +from ..utils.torch_utils import empty_device_cache, is_torch_dist_rank_zero from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig from .model_loading_utils import ( _caching_allocator_warmup, @@ -1672,7 +1672,10 @@ def _load_pretrained_model( else: shard_files = resolved_model_file if len(resolved_model_file) > 1: - shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards") + shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"} + if not is_torch_dist_rank_zero(): + shard_tqdm_kwargs["disable"] = True + shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs) for shard_file in shard_files: offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 6f37208227db..de1d44df4d3c 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -68,7 +68,7 @@ numpy_to_pil, ) from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card -from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module +from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module, is_torch_dist_rank_zero if is_torch_npu_available(): @@ -983,7 +983,7 @@ def load_module(name, value): current_device_map = None _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config) logging_tqdm_kwargs = {"desc": "Loading pipeline components..."} - if cls._progress_bar_disabled_for_rank(): + if not is_torch_dist_rank_zero(): logging_tqdm_kwargs["disable"] = True for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs): @@ -1914,7 +1914,7 
@@ def progress_bar(self, iterable=None, total=None): progress_bar_config = dict(self._progress_bar_config) if "disable" not in progress_bar_config: - progress_bar_config["disable"] = self._progress_bar_disabled_for_rank() + progress_bar_config["disable"] = not is_torch_dist_rank_zero() if iterable is not None: return tqdm(iterable, **progress_bar_config) @@ -1926,15 +1926,6 @@ def progress_bar(self, iterable=None, total=None): def set_progress_bar_config(self, **kwargs): self._progress_bar_config = kwargs - @staticmethod - def _progress_bar_disabled_for_rank(): - if torch.distributed.is_available() and torch.distributed.is_initialized(): - try: - return torch.distributed.get_rank() != 0 - except (RuntimeError, ValueError): - return False - return False - def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): r""" Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). When this diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 3b66fdadbef8..b8a61b9cdb1a 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -143,6 +143,23 @@ def backend_supports_training(device: str): return BACKEND_SUPPORTS_TRAINING[device] +def is_torch_dist_rank_zero() -> bool: + if not is_torch_available(): + return True + + dist_module = getattr(torch, "distributed", None) + if dist_module is None or not dist_module.is_available(): + return True + + if not dist_module.is_initialized(): + return True + + try: + return dist_module.get_rank() == 0 + except (RuntimeError, ValueError): + return True + + def randn_tensor( shape: Union[Tuple, List], generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, From 94d671d03a2cdecac1a66285480a0b0af2f85cbc Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 8 Dec 2025 19:01:22 +0530 Subject: [PATCH 4/7] up --- src/diffusers/models/model_loading_utils.py | 7 +++- src/diffusers/models/modeling_utils.py | 9 ++---- src/diffusers/pipelines/pipeline_utils.py | 3 +- src/diffusers/utils/distributed_utils.py | 36 +++++++++++++++++++++ src/diffusers/utils/logging.py | 23 ++++++++++++- src/diffusers/utils/torch_utils.py | 17 ---------- 6 files changed, 69 insertions(+), 26 deletions(-) create mode 100644 src/diffusers/utils/distributed_utils.py diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 8b48ba6b4873..08b3f0234f82 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -47,6 +47,7 @@ is_torch_version, logging, ) +from ..utils.distributed_utils import is_torch_dist_rank_zero logger = logging.get_logger(__name__) @@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool( low_cpu_mem_usage=low_cpu_mem_usage, ) + tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"} + if not is_torch_dist_rank_zero(): + tqdm_kwargs["disable"] = True + with ThreadPoolExecutor(max_workers=num_workers) as executor: - with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar: + with logging.tqdm(**tqdm_kwargs) as pbar: futures = [executor.submit(load_one, shard_file) for shard_file in shard_files] for future in as_completed(futures): result = future.result() diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 567fb424d4de..ba5b93605f50 100644 --- a/src/diffusers/models/modeling_utils.py +++ 
b/src/diffusers/models/modeling_utils.py @@ -59,12 +59,9 @@ is_torch_version, logging, ) -from ..utils.hub_utils import ( - PushToHubMixin, - load_or_create_model_card, - populate_model_card, -) -from ..utils.torch_utils import empty_device_cache, is_torch_dist_rank_zero +from ..utils.distributed_utils import is_torch_dist_rank_zero +from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card +from ..utils.torch_utils import empty_device_cache from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig from .model_loading_utils import ( _caching_allocator_warmup, diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index de1d44df4d3c..5c4ac8a6554a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -67,8 +67,9 @@ logging, numpy_to_pil, ) +from ..utils.distributed_utils import is_torch_dist_rank_zero from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card -from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module, is_torch_dist_rank_zero +from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module if is_torch_npu_available(): diff --git a/src/diffusers/utils/distributed_utils.py b/src/diffusers/utils/distributed_utils.py new file mode 100644 index 000000000000..239b7b26200d --- /dev/null +++ b/src/diffusers/utils/distributed_utils.py @@ -0,0 +1,36 @@ +# Copyright 2025 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +try: + import torch +except ImportError: + torch = None + + +def is_torch_dist_rank_zero() -> bool: + if torch is None: + return True + + dist_module = getattr(torch, "distributed", None) + if dist_module is None or not dist_module.is_available(): + return True + + if not dist_module.is_initialized(): + return True + + try: + return dist_module.get_rank() == 0 + except (RuntimeError, ValueError): + return True diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py index 2ad6d3a47607..48a5d900af56 100644 --- a/src/diffusers/utils/logging.py +++ b/src/diffusers/utils/logging.py @@ -32,6 +32,8 @@ from tqdm import auto as tqdm_lib +from .distributed_utils import is_torch_dist_rank_zero + _lock = threading.Lock() _default_handler: Optional[logging.Handler] = None @@ -47,6 +49,22 @@ _default_log_level = logging.WARNING _tqdm_active = True +_rank_zero_filter = None + + +class _RankZeroFilter(logging.Filter): + def filter(self, record): + return is_torch_dist_rank_zero() + + +def _ensure_rank_zero_filter(logger: logging.Logger) -> None: + global _rank_zero_filter + + if _rank_zero_filter is None: + _rank_zero_filter = _RankZeroFilter() + + if not any(isinstance(f, _RankZeroFilter) for f in logger.filters): + logger.addFilter(_rank_zero_filter) def _get_default_logging_level() -> int: @@ -90,6 +108,7 @@ def _configure_library_root_logger() -> None: library_root_logger.addHandler(_default_handler) library_root_logger.setLevel(_get_default_logging_level()) library_root_logger.propagate = False + _ensure_rank_zero_filter(library_root_logger) def _reset_library_root_logger() -> None: @@ -120,7 +139,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger: name = _get_library_name() _configure_library_root_logger() - return logging.getLogger(name) + logger = logging.getLogger(name) + _ensure_rank_zero_filter(logger) + return logger def get_verbosity() -> int: diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index b8a61b9cdb1a..3b66fdadbef8 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -143,23 +143,6 @@ def backend_supports_training(device: str): return BACKEND_SUPPORTS_TRAINING[device] -def is_torch_dist_rank_zero() -> bool: - if not is_torch_available(): - return True - - dist_module = getattr(torch, "distributed", None) - if dist_module is None or not dist_module.is_available(): - return True - - if not dist_module.is_initialized(): - return True - - try: - return dist_module.get_rank() == 0 - except (RuntimeError, ValueError): - return True - - def randn_tensor( shape: Union[Tuple, List], generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, From 2529fdf6af4de622066a07aeb204048da00bb727 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Dec 2025 08:31:10 +0530 Subject: [PATCH 5/7] up --- src/diffusers/utils/distributed_utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/diffusers/utils/distributed_utils.py b/src/diffusers/utils/distributed_utils.py index 239b7b26200d..154ecab3a5d0 100644 --- a/src/diffusers/utils/distributed_utils.py +++ b/src/diffusers/utils/distributed_utils.py @@ -13,16 +13,15 @@ # limitations under the License. 
-try: - import torch -except ImportError: - torch = None +from .import_utils import is_torch_available def is_torch_dist_rank_zero() -> bool: - if torch is None: + if not is_torch_available(): return True + import torch + dist_module = getattr(torch, "distributed", None) if dist_module is None or not dist_module.is_available(): return True From 8193f38e1970ce1f695230fe785e3ea1b43f7766 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 9 Dec 2025 08:37:06 +0530 Subject: [PATCH 6/7] up --- src/diffusers/utils/distributed_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/utils/distributed_utils.py b/src/diffusers/utils/distributed_utils.py index 154ecab3a5d0..239b7b26200d 100644 --- a/src/diffusers/utils/distributed_utils.py +++ b/src/diffusers/utils/distributed_utils.py @@ -13,15 +13,16 @@ # limitations under the License. -from .import_utils import is_torch_available +try: + import torch +except ImportError: + torch = None def is_torch_dist_rank_zero() -> bool: - if not is_torch_available(): + if torch is None: return True - import torch - dist_module = getattr(torch, "distributed", None) if dist_module is None or not dist_module.is_available(): return True From 086a770174b8eb1b550dd9f2c7b8a1bc996d7a89 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 10 Dec 2025 08:45:42 +0530 Subject: [PATCH 7/7] up --- src/diffusers/utils/logging.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py index 48a5d900af56..80e108e4a6ff 100644 --- a/src/diffusers/utils/logging.py +++ b/src/diffusers/utils/logging.py @@ -54,7 +54,8 @@ class _RankZeroFilter(logging.Filter): def filter(self, record): - return is_torch_dist_rank_zero() + # Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting. + return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
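---

Usage sketch, not part of the patch series: assuming the seven patches above are applied to diffusers, the script below shows the intended end-to-end behavior of `is_torch_dist_rank_zero()` and the `_RankZeroFilter`. The file name `rank_zero_demo.py` and the two-process launch command are illustrative only; everything else uses helpers that the patches themselves introduce (`diffusers.utils.distributed_utils.is_torch_dist_rank_zero`) or that already exist in `diffusers.utils.logging` (`get_logger`, `set_verbosity_debug`).

# rank_zero_demo.py -- minimal sketch, assumes the patches above are applied.
# Launch (illustrative): torchrun --nproc_per_node=2 rank_zero_demo.py
import torch.distributed as dist

from diffusers.utils import logging
from diffusers.utils.distributed_utils import is_torch_dist_rank_zero


def main():
    # gloo keeps the demo runnable on CPU-only hosts; torchrun supplies
    # MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE for env:// initialization.
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    # Only rank 0 reports True, so the tqdm bars in `progress_bar`, the
    # pipeline component loop, and checkpoint-shard loading are silenced
    # on every other rank.
    print(f"rank {rank}: is_torch_dist_rank_zero() -> {is_torch_dist_rank_zero()}")

    logging.set_verbosity_debug()
    logger = logging.get_logger("diffusers")
    # _RankZeroFilter drops this record on ranks != 0 ...
    logger.warning(f"rank {rank}: printed once, from rank 0")
    # ... but DEBUG records pass the filter on every rank (PATCH 7/7).
    logger.debug(f"rank {rank}: printed by all ranks")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Note on the design: `progress_bar` only fills in `disable` when the caller has not set it (`if "disable" not in progress_bar_config`), so an explicit `pipe.set_progress_bar_config(disable=False)` still wins and shows bars on every rank.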