Skip to content

Commit 23f2ab3

Browse files
yuxinyuanYU XinyuanSunMarcgithub-actions[bot]
authored
Fix logging logic when in_order is set to True (#3280)
* Fix logging logic when in_order is set to True * Fix annotation for py39 * Fix logging related tests * Apply style fixes --------- Co-authored-by: YU Xinyuan <yuxinyuan02@corp.netease.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 58c3605 commit 23f2ab3

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

src/accelerate/logging.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

1516
import functools
1617
import logging
1718
import os
18-
from typing import Optional
1919

2020
from .state import PartialState
2121

@@ -36,6 +36,16 @@ def _should_log(main_process_only):
3636
state = PartialState()
3737
return not main_process_only or (main_process_only and state.is_main_process)
3838

39+
def process(self, msg, kwargs):
40+
msg, kwargs = super().process(msg, kwargs)
41+
42+
# set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
43+
kwargs.setdefault("stacklevel", 2)
44+
45+
state = PartialState()
46+
msg = f"[RANK {state.process_index}] {msg}"
47+
return msg, kwargs
48+
3949
def log(self, level, msg, *args, **kwargs):
4050
"""
4151
Delegates logger call after checking if we should log.
@@ -47,27 +57,24 @@ def log(self, level, msg, *args, **kwargs):
4757
read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
4858
break with the previous behavior.
4959
50-
`in_order` is ignored if `main_process_only` is passed.
60+
`main_process_only` is ignored if `in_order` is passed.
5161
"""
5262
if PartialState._shared_state == {}:
5363
raise RuntimeError(
5464
"You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
5565
)
5666
main_process_only = kwargs.pop("main_process_only", True)
5767
in_order = kwargs.pop("in_order", False)
58-
# set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
59-
kwargs.setdefault("stacklevel", 2)
6068

6169
if self.isEnabledFor(level):
62-
if self._should_log(main_process_only):
63-
msg, kwargs = self.process(msg, kwargs)
70+
msg, kwargs = self.process(msg, kwargs)
71+
if not in_order and self._should_log(main_process_only):
6472
self.logger.log(level, msg, *args, **kwargs)
6573

6674
elif in_order:
6775
state = PartialState()
6876
for i in range(state.num_processes):
6977
if i == state.process_index:
70-
msg, kwargs = self.process(msg, kwargs)
7178
self.logger.log(level, msg, *args, **kwargs)
7279
state.wait_for_everyone()
7380

@@ -83,7 +90,7 @@ def warning_once(self, *args, **kwargs):
8390
self.warning(*args, **kwargs)
8491

8592

86-
def get_logger(name: str, log_level: Optional[str] = None):
93+
def get_logger(name: str, log_level: str | None = None):
8794
"""
8895
Returns a `logging.Logger` for `name` that can handle multiprocessing.
8996

tests/test_logging.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def test_log_stack(caplog):
5656
)
5757

5858
message = "Test"
59+
expected_message, _ = logger.process(message, {})
5960
lineno = current_lineno() + 1 # the next line is the actual callsite
6061
logger.warning(message)
6162

@@ -66,7 +67,7 @@ def test_log_stack(caplog):
6667
assert rec.name == __name__
6768
assert rec.lineno == lineno
6869
assert rec.funcName == test_log_stack.__name__
69-
assert rec.message == message
70+
assert rec.message == expected_message
7071

7172

7273
@pytest.mark.usefixtures("accelerator")
@@ -79,6 +80,7 @@ def test_custom_stacklevel(caplog):
7980
logger = CustomLogger(wrapped_logger, {})
8081

8182
message = "Test"
83+
expected_message, _ = wrapped_logger.process(message, {})
8284
lineno = current_lineno() + 1 # the next line is the actual callsite
8385
logger.warning(message)
8486

@@ -91,4 +93,4 @@ def test_custom_stacklevel(caplog):
9193
assert rec.name == __name__
9294
assert rec.lineno == lineno
9395
assert rec.funcName == test_custom_stacklevel.__name__
94-
assert rec.message == message
96+
assert rec.message == expected_message

0 commit comments

Comments
 (0)