1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ from __future__ import annotations
1415
1516import functools
1617import logging
1718import os
18- from typing import Optional
1919
2020from .state import PartialState
2121
@@ -36,6 +36,16 @@ def _should_log(main_process_only):
3636 state = PartialState ()
3737 return not main_process_only or (main_process_only and state .is_main_process )
3838
39+ def process (self , msg , kwargs ):
40+ msg , kwargs = super ().process (msg , kwargs )
41+
42+ # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
43+ kwargs .setdefault ("stacklevel" , 2 )
44+
45+ state = PartialState ()
46+ msg = f"[RANK { state .process_index } ] { msg } "
47+ return msg , kwargs
48+
3949 def log (self , level , msg , * args , ** kwargs ):
4050 """
4151 Delegates logger call after checking if we should log.
@@ -47,27 +57,24 @@ def log(self, level, msg, *args, **kwargs):
4757 read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
4858 break with the previous behavior.
4959
50- `in_order ` is ignored if `main_process_only ` is passed.
60+ `main_process_only ` is ignored if `in_order ` is passed.
5161 """
5262 if PartialState ._shared_state == {}:
5363 raise RuntimeError (
5464 "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
5565 )
5666 main_process_only = kwargs .pop ("main_process_only" , True )
5767 in_order = kwargs .pop ("in_order" , False )
58- # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
59- kwargs .setdefault ("stacklevel" , 2 )
6068
6169 if self .isEnabledFor (level ):
62- if self ._should_log ( main_process_only ):
63- msg , kwargs = self .process ( msg , kwargs )
70+ msg , kwargs = self .process ( msg , kwargs )
71+ if not in_order and self ._should_log ( main_process_only ):
6472 self .logger .log (level , msg , * args , ** kwargs )
6573
6674 elif in_order :
6775 state = PartialState ()
6876 for i in range (state .num_processes ):
6977 if i == state .process_index :
70- msg , kwargs = self .process (msg , kwargs )
7178 self .logger .log (level , msg , * args , ** kwargs )
7279 state .wait_for_everyone ()
7380
@@ -83,7 +90,7 @@ def warning_once(self, *args, **kwargs):
8390 self .warning (* args , ** kwargs )
8491
8592
86- def get_logger (name : str , log_level : Optional [ str ] = None ):
93+ def get_logger (name : str , log_level : str | None = None ):
8794 """
8895 Returns a `logging.Logger` for `name` that can handle multiprocessing.
8996
0 commit comments