
Commit 94d671d

Commit message: up
Parent: 8df3fbc

6 files changed: 69 additions, 26 deletions

src/diffusers/models/model_loading_utils.py
Lines changed: 6 additions & 1 deletion

@@ -47,6 +47,7 @@
     is_torch_version,
     logging,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero


 logger = logging.get_logger(__name__)
@@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool(
         low_cpu_mem_usage=low_cpu_mem_usage,
     )

+    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
+    if not is_torch_dist_rank_zero():
+        tqdm_kwargs["disable"] = True
+
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
-        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+        with logging.tqdm(**tqdm_kwargs) as pbar:
             futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
             for future in as_completed(futures):
                 result = future.result()
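The change above keeps the shard-loading progress bar quiet everywhere except global rank 0 by passing disable=True to tqdm on other ranks. A minimal standalone sketch of the same pattern, assuming a script launched with torchrun (the helper name rank_zero_tqdm is made up for illustration, not part of the commit):

    import torch.distributed as dist
    from tqdm.auto import tqdm

    def rank_zero_tqdm(**kwargs):
        # Mirror the commit's logic: force disable=True on every rank
        # except global rank 0, so only one process draws the bar.
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            kwargs["disable"] = True
        return tqdm(**kwargs)

    # Usage inside a distributed job:
    # with rank_zero_tqdm(total=10, desc="Loading checkpoint shards") as pbar:
    #     for _ in range(10):
    #         pbar.update(1)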

src/diffusers/models/modeling_utils.py
Lines changed: 3 additions & 6 deletions

@@ -59,12 +59,9 @@
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import (
-    PushToHubMixin,
-    load_or_create_model_card,
-    populate_model_card,
-)
-from ..utils.torch_utils import empty_device_cache, is_torch_dist_rank_zero
+from ..utils.distributed_utils import is_torch_dist_rank_zero
+from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
+from ..utils.torch_utils import empty_device_cache
 from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
 from .model_loading_utils import (
     _caching_allocator_warmup,

src/diffusers/pipelines/pipeline_utils.py
Lines changed: 2 additions & 1 deletion

@@ -67,8 +67,9 @@
     logging,
     numpy_to_pil,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
-from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module, is_torch_dist_rank_zero
+from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module


 if is_torch_npu_available():
src/diffusers/utils/distributed_utils.py (new file)
Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+def is_torch_dist_rank_zero() -> bool:
+    if torch is None:
+        return True
+
+    dist_module = getattr(torch, "distributed", None)
+    if dist_module is None or not dist_module.is_available():
+        return True
+
+    if not dist_module.is_initialized():
+        return True
+
+    try:
+        return dist_module.get_rank() == 0
+    except (RuntimeError, ValueError):
+        return True
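The helper deliberately falls back to True whenever torch is missing, torch.distributed is unavailable, or no process group is initialized, so single-process runs behave as rank 0. A usage sketch; the gloo backend, the two-process launch, and the script name demo.py are assumptions for illustration, not part of the commit:

    # Run with: torchrun --nproc_per_node=2 demo.py
    import torch.distributed as dist
    from diffusers.utils.distributed_utils import is_torch_dist_rank_zero

    # Before init_process_group, the helper returns True on every process.
    assert is_torch_dist_rank_zero()

    dist.init_process_group(backend="gloo")
    if is_torch_dist_rank_zero():
        print("printed exactly once, by global rank 0")
    dist.destroy_process_group()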

src/diffusers/utils/logging.py
Lines changed: 22 additions & 1 deletion

@@ -32,6 +32,8 @@

 from tqdm import auto as tqdm_lib

+from .distributed_utils import is_torch_dist_rank_zero
+

 _lock = threading.Lock()
 _default_handler: Optional[logging.Handler] = None
@@ -47,6 +49,22 @@
 _default_log_level = logging.WARNING

 _tqdm_active = True
+_rank_zero_filter = None
+
+
+class _RankZeroFilter(logging.Filter):
+    def filter(self, record):
+        return is_torch_dist_rank_zero()
+
+
+def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
+    global _rank_zero_filter
+
+    if _rank_zero_filter is None:
+        _rank_zero_filter = _RankZeroFilter()
+
+    if not any(isinstance(f, _RankZeroFilter) for f in logger.filters):
+        logger.addFilter(_rank_zero_filter)


 def _get_default_logging_level() -> int:
@@ -90,6 +108,7 @@ def _configure_library_root_logger() -> None:
         library_root_logger.addHandler(_default_handler)
         library_root_logger.setLevel(_get_default_logging_level())
         library_root_logger.propagate = False
+        _ensure_rank_zero_filter(library_root_logger)


 def _reset_library_root_logger() -> None:
@@ -120,7 +139,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
         name = _get_library_name()

     _configure_library_root_logger()
-    return logging.getLogger(name)
+    logger = logging.getLogger(name)
+    _ensure_rank_zero_filter(logger)
+    return logger


 def get_verbosity() -> int:
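The filter relies on standard-library logging semantics: a filter attached to a logger that returns False drops the record before any handler formats it. A self-contained sketch of the mechanism, needing neither diffusers nor torch; the hard-coded rank argument stands in for torch.distributed.get_rank():

    import logging

    class RankZeroFilter(logging.Filter):
        def __init__(self, rank: int):
            super().__init__()
            self.rank = rank

        def filter(self, record: logging.LogRecord) -> bool:
            # False means the record is discarded before reaching handlers.
            return self.rank == 0

    logging.basicConfig()
    logger = logging.getLogger("demo")
    logger.addFilter(RankZeroFilter(rank=1))
    logger.warning("never shown")   # dropped by the filter

    logger.filters.clear()
    logger.addFilter(RankZeroFilter(rank=0))
    logger.warning("shown once")    # passes the filter

Note that filters attached to a logger, unlike filters attached to a handler, do not apply to records propagated up from child loggers; that is consistent with the commit attaching the filter both on the library root logger and on every logger returned by get_logger().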

src/diffusers/utils/torch_utils.py
Lines changed: 0 additions & 17 deletions

@@ -143,23 +143,6 @@ def backend_supports_training(device: str):
     return BACKEND_SUPPORTS_TRAINING[device]


-def is_torch_dist_rank_zero() -> bool:
-    if not is_torch_available():
-        return True
-
-    dist_module = getattr(torch, "distributed", None)
-    if dist_module is None or not dist_module.is_available():
-        return True
-
-    if not dist_module.is_initialized():
-        return True
-
-    try:
-        return dist_module.get_rank() == 0
-    except (RuntimeError, ValueError):
-        return True
-
-
 def randn_tensor(
     shape: Union[Tuple, List],
     generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
