|
32 | 32 |
|
33 | 33 | from tqdm import auto as tqdm_lib |
34 | 34 |
|
| 35 | +from .distributed_utils import is_torch_dist_rank_zero |
| 36 | + |
35 | 37 |
|
36 | 38 | _lock = threading.Lock() |
37 | 39 | _default_handler: Optional[logging.Handler] = None |
|
47 | 49 | _default_log_level = logging.WARNING |
48 | 50 |
|
49 | 51 | _tqdm_active = True |
| 52 | +_rank_zero_filter = None |
| 53 | + |
| 54 | + |
| 55 | +class _RankZeroFilter(logging.Filter): |
| 56 | + def filter(self, record): |
| 57 | + return is_torch_dist_rank_zero() |
| 58 | + |
| 59 | + |
| 60 | +def _ensure_rank_zero_filter(logger: logging.Logger) -> None: |
| 61 | + global _rank_zero_filter |
| 62 | + |
| 63 | + if _rank_zero_filter is None: |
| 64 | + _rank_zero_filter = _RankZeroFilter() |
| 65 | + |
| 66 | + if not any(isinstance(f, _RankZeroFilter) for f in logger.filters): |
| 67 | + logger.addFilter(_rank_zero_filter) |
50 | 68 |
|
51 | 69 |
|
52 | 70 | def _get_default_logging_level() -> int: |
@@ -90,6 +108,7 @@ def _configure_library_root_logger() -> None: |
90 | 108 | library_root_logger.addHandler(_default_handler) |
91 | 109 | library_root_logger.setLevel(_get_default_logging_level()) |
92 | 110 | library_root_logger.propagate = False |
| 111 | + _ensure_rank_zero_filter(library_root_logger) |
93 | 112 |
|
94 | 113 |
|
95 | 114 | def _reset_library_root_logger() -> None: |
@@ -120,7 +139,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger: |
120 | 139 | name = _get_library_name() |
121 | 140 |
|
122 | 141 | _configure_library_root_logger() |
123 | | - return logging.getLogger(name) |
| 142 | + logger = logging.getLogger(name) |
| 143 | + _ensure_rank_zero_filter(logger) |
| 144 | + return logger |
124 | 145 |
|
125 | 146 |
|
126 | 147 | def get_verbosity() -> int: |
|
0 commit comments