66import os
77import sys
88import time
9- from typing import Mapping , Optional
9+ from typing import Mapping , Optional , Union
1010
1111from forge .interfaces import MetricLogger
1212from forge .types import Scalar
@@ -21,11 +21,12 @@ class StdoutLogger(MetricLogger):
2121 """Logger to standard output.
2222
2323 Args:
24- freq (Mapping[str, int]):
25- calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
24+ freq (Union[int, Mapping[str, int]]):
25+ If int, all metrics will be logged at this frequency.
26+ If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
2627 """
2728
28- def __init__ (self , freq : Mapping [str , int ]):
29+ def __init__ (self , freq : Union [ int , Mapping [str , int ] ]):
2930 self ._freq = freq
3031
3132 def is_log_step (self , name : str , step : int ) -> bool :
@@ -35,6 +36,8 @@ def is_log_step(self, name: str, step: int) -> bool:
3536 name (str): metric name (for checking the freq for this metric)
3637 step (int): current step
3738 """
39+ if isinstance (self ._freq , int ):
40+ return step % self ._freq == 0
3841 return step % self ._freq [name ] == 0
3942
4043 def log (self , name : str , data : Scalar , step : int ) -> None :
@@ -77,8 +80,9 @@ class TensorBoardLogger(MetricLogger):
7780 """Logger for use w/ PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).
7881
7982 Args:
80- freq (Mapping[str, int]):
81- calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
83+ freq (Union[int, Mapping[str, int]]):
84+ If int, all metrics will be logged at this frequency.
85+ If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
8286 log_dir (str): torch.TensorBoard log directory
8387 organize_logs (bool): If `True`, this class will create a subdirectory within `log_dir` for the current
8488 run. Having sub-directories allows you to compare logs across runs. When TensorBoard is
@@ -103,7 +107,7 @@ class TensorBoardLogger(MetricLogger):
103107
104108 def __init__ (
105109 self ,
106- freq : Mapping [str , int ],
110+ freq : Union [ int , Mapping [str , int ] ],
107111 log_dir : str = "metrics_log" ,
108112 organize_logs : bool = True ,
109113 ** kwargs ,
@@ -133,6 +137,8 @@ def is_log_step(self, name: str, step: int) -> bool:
133137 name (str): metric name (for checking the freq for this metric)
134138 step (int): current step
135139 """
140+ if isinstance (self ._freq , int ):
141+ return step % self ._freq == 0
136142 return step % self ._freq [name ] == 0
137143
138144 def log (self , name : str , data : Scalar , step : int ) -> None :
@@ -168,8 +174,9 @@ class WandBLogger(MetricLogger):
168174 For more information about arguments expected by WandB, see https://docs.wandb.ai/ref/python/init.
169175
170176 Args:
171- freq (Mapping[str, int]):
172- calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
177+ freq (Union[int, Mapping[str, int]]):
178+ If int, all metrics will be logged at this frequency.
179+ If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
173180 log_dir (Optional[str]): WandB log directory.
174181 project (str): WandB project name. Default is `torchtune`.
175182 entity (Optional[str]): WandB entity name. If you don't specify an entity,
@@ -197,7 +204,7 @@ class WandBLogger(MetricLogger):
197204
198205 def __init__ (
199206 self ,
200- freq : Mapping [str , int ],
207+ freq : Union [ int , Mapping [str , int ] ],
201208 project : str ,
202209 log_dir : str = "metrics_log" ,
203210 entity : Optional [str ] = None ,
@@ -241,6 +248,8 @@ def is_log_step(self, name: str, step: int) -> bool:
241248 name (str): metric name (for checking the freq for this metric)
242249 step (int): current step
243250 """
251+ if isinstance (self ._freq , int ):
252+ return step % self ._freq == 0
244253 return step % self ._freq [name ] == 0
245254
246255 def log (self , name : str , data : Scalar , step : int ) -> None :
0 commit comments