@@ -38,11 +38,24 @@ def is_log_step(self, name: str, step: int) -> bool:
         return step % self._freq[name] == 0

     def log(self, name: str, data: Scalar, step: int) -> None:
+        """Log the metric if it is a logging step.
+
+        Args:
+            name (str): metric name
+            data (Scalar): metric value
+            step (int): current step
+        """
         if not self.is_log_step(name, step):
             return
         print(f"Step {step} | {name}:{data}")

     def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
+        """Log the metrics for which this is currently a logging step.
+
+        Args:
+            metrics (Mapping[str, Scalar]): dict of metric names and values
+            step (int): current step
+        """
         log_step_metrics = {
             name: value
             for name, value in metrics.items()
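For a reader skimming the diff, here is a self-contained sketch of the gating pattern the hunk above documents. The class name `StdoutLogger`, the `freq` constructor argument, and the `Scalar` alias are assumptions inferred from the visible context, not the PR's verbatim definitions:

```python
from typing import Mapping, Union

Scalar = Union[int, float]  # assumed alias; the real type lives elsewhere in the module

class StdoutLogger:
    """Sketch: print each metric every freq[name] steps."""

    def __init__(self, freq: Mapping[str, int]) -> None:
        self._freq = freq  # metric name -> logging interval in steps

    def is_log_step(self, name: str, step: int) -> bool:
        # A metric is logged when its interval divides the current step.
        return step % self._freq[name] == 0

    def log(self, name: str, data: Scalar, step: int) -> None:
        if not self.is_log_step(name, step):
            return
        print(f"Step {step} | {name}:{data}")

logger = StdoutLogger(freq={"loss": 10})
logger.log("loss", 0.42, step=10)  # prints: Step 10 | loss:0.42
logger.log("loss", 0.40, step=11)  # suppressed: 11 % 10 != 0
```

Note that step 0 always counts as a logging step under this rule, since `0 % n == 0` for any interval.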
@@ -76,8 +89,8 @@ class TensorBoardLogger(MetricLogger):
         **kwargs: additional arguments

     Example:
-        >>> from torchtune.training.metric_logging import TensorBoardLogger
-        >>> logger = TensorBoardLogger(log_dir="my_log_dir")
+        >>> from forge.util.metric_logging import TensorBoardLogger
+        >>> logger = TensorBoardLogger(freq={"loss": 10}, log_dir="my_log_dir")
         >>> logger.log("my_metric", 1.0, 1)
         >>> logger.log_dict({"my_metric": 1.0}, 1)
         >>> logger.close()
@@ -123,10 +136,23 @@ def is_log_step(self, name: str, step: int) -> bool:
         return step % self._freq[name] == 0

     def log(self, name: str, data: Scalar, step: int) -> None:
+        """Log the metric to TensorBoard.
+
+        Args:
+            name (str): metric name
+            data (Scalar): metric value
+            step (int): current step
+        """
         if self._writer:
             self._writer.add_scalar(name, data, global_step=step, new_style=True)

     def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
+        """Log the metrics for which this is currently a logging step.
+
+        Args:
+            metrics (Mapping[str, Scalar]): dict of metric names and values
+            step (int): current step
+        """
         for name, data in metrics.items():
             if self.is_log_step(name, step):
                 self.log(name, data, step)
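One subtlety in this hunk: unlike the stdout and W&B variants, this `log` writes unconditionally, and the per-metric frequency gate is applied only in `log_dict` (which checks `is_log_step` before delegating to `log`), which is why the docstring above describes the write rather than the gate. For reference, a minimal sketch of the underlying TensorBoard call; the writer construction here is illustrative, not this class's actual `__init__`:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="my_log_dir")
# new_style=True stores the scalar via the newer tensor-based format
# rather than the legacy simple_value field.
writer.add_scalar("loss", 0.42, global_step=10, new_style=True)
writer.close()
```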
@@ -153,8 +179,8 @@ class WandBLogger(MetricLogger):
         **kwargs: additional arguments to pass to wandb.init

     Example:
-        >>> from torchtune.training.metric_logging import WandBLogger
-        >>> logger = WandBLogger(log_dir="wandb", project="my_project", entity="my_entity", group="my_group")
+        >>> from forge.util.metric_logging import WandBLogger
+        >>> logger = WandBLogger(freq={"loss": 10}, log_dir="wandb", project="my_project")
         >>> logger.log("my_metric", 1.0, 1)
         >>> logger.log_dict({"my_metric": 1.0}, 1)
         >>> logger.close()
@@ -218,10 +244,23 @@ def is_log_step(self, name: str, step: int) -> bool:
         return step % self._freq[name] == 0

     def log(self, name: str, data: Scalar, step: int) -> None:
+        """Log the metric if it is a logging step.
+
+        Args:
+            name (str): metric name
+            data (Scalar): metric value
+            step (int): current step
+        """
         if self._wandb.run and self.is_log_step(name, step):
             self._wandb.log({name: data, "step": step})

     def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
+        """Log the metrics for which this is currently a logging step.
+
+        Args:
+            metrics (Mapping[str, Scalar]): dict of metric names and values
+            step (int): current step
+        """
         log_step_metrics = {
             name: value
             for name, value in metrics.items()
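The hunk context ends mid-comprehension, so the filter condition and the final W&B call are not visible here. As a hedged sketch of the shape the surrounding code implies (filter the batch with the same modulo rule, then emit one call for the survivors); the trailing condition and the `wandb.log` payload are assumptions, not the PR's verbatim code:

```python
import wandb
from typing import Mapping, Union

Scalar = Union[int, float]  # assumed alias, as above

def log_dict(metrics: Mapping[str, Scalar], step: int, freq: Mapping[str, int]) -> None:
    """Sketch of a batched, frequency-filtered W&B logging path."""
    log_step_metrics = {
        name: value
        for name, value in metrics.items()
        if step % freq[name] == 0  # assumed continuation of the comprehension
    }
    if log_step_metrics and wandb.run is not None:
        # One wandb.log call per step, recording the step as a plain field.
        wandb.log({**log_step_metrics, "step": step})
```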