29 changes: 19 additions & 10 deletions src/forge/util/metric_logging.py
@@ -6,7 +6,7 @@
 import os
 import sys
 import time
-from typing import Mapping, Optional
+from typing import Mapping, Optional, Union

 from forge.interfaces import MetricLogger
 from forge.types import Scalar
@@ -21,11 +21,12 @@ class StdoutLogger(MetricLogger):
     """Logger to standard output.

     Args:
-        freq (Mapping[str, int]):
-            calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
+        freq (Union[int, Mapping[str, int]]):
+            If int, all metrics will be logged at this frequency.
+            If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
     """

-    def __init__(self, freq: Mapping[str, int]):
+    def __init__(self, freq: Union[int, Mapping[str, int]]):
         self._freq = freq

     def is_log_step(self, name: str, step: int) -> bool:
@@ -35,6 +36,8 @@ def is_log_step(self, name: str, step: int) -> bool:
             name (str): metric name (for checking the freq for this metric)
             step (int): current step
         """
+        if isinstance(self._freq, int):
+            return step % self._freq == 0
         return step % self._freq[name] == 0

     def log(self, name: str, data: Scalar, step: int) -> None:
@@ -77,8 +80,9 @@ class TensorBoardLogger(MetricLogger):
     """Logger for use w/ PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).

     Args:
-        freq (Mapping[str, int]):
-            calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
+        freq (Union[int, Mapping[str, int]]):
+            If int, all metrics will be logged at this frequency.
+            If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
         log_dir (str): torch.TensorBoard log directory
         organize_logs (bool): If `True`, this class will create a subdirectory within `log_dir` for the current
             run. Having sub-directories allows you to compare logs across runs. When TensorBoard is
@@ -103,7 +107,7 @@ class TensorBoardLogger(MetricLogger):

     def __init__(
         self,
-        freq: Mapping[str, int],
+        freq: Union[int, Mapping[str, int]],
         log_dir: str = "metrics_log",
         organize_logs: bool = True,
         **kwargs,
@@ -133,6 +137,8 @@ def is_log_step(self, name: str, step: int) -> bool:
             name (str): metric name (for checking the freq for this metric)
             step (int): current step
         """
+        if isinstance(self._freq, int):
+            return step % self._freq == 0
         return step % self._freq[name] == 0

     def log(self, name: str, data: Scalar, step: int) -> None:
@@ -168,8 +174,9 @@ class WandBLogger(MetricLogger):
     For more information about arguments expected by WandB, see https://docs.wandb.ai/ref/python/init.

     Args:
-        freq (Mapping[str, int]):
-            calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
+        freq (Union[int, Mapping[str, int]]):
+            If int, all metrics will be logged at this frequency.
+            If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
         log_dir (Optional[str]): WandB log directory.
         project (str): WandB project name. Default is `torchtune`.
         entity (Optional[str]): WandB entity name. If you don't specify an entity,
@@ -197,7 +204,7 @@ class WandBLogger(MetricLogger):

     def __init__(
         self,
-        freq: Mapping[str, int],
+        freq: Union[int, Mapping[str, int]],
         project: str,
         log_dir: str = "metrics_log",
         entity: Optional[str] = None,
@@ -241,6 +248,8 @@ def is_log_step(self, name: str, step: int) -> bool:
             name (str): metric name (for checking the freq for this metric)
             step (int): current step
         """
+        if isinstance(self._freq, int):
+            return step % self._freq == 0
         return step % self._freq[name] == 0

     def log(self, name: str, data: Scalar, step: int) -> None:
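The change lets callers pass a single int to apply one logging frequency to every metric, while the per-metric Mapping form keeps its old behavior. Below is a minimal usage sketch, not part of the PR itself: it only relies on `StdoutLogger`, `is_log_step`, and `log` as shown in this diff (the `log_dict` signature is not visible here, so it is left out), and the same pattern would apply to `TensorBoardLogger` and `WandBLogger`.

```python
# Sketch of the int-or-Mapping `freq` argument introduced in this diff.
from forge.util.metric_logging import StdoutLogger

# Single int: every metric shares the same frequency.
logger = StdoutLogger(freq=10)
assert logger.is_log_step("loss", 20)      # 20 % 10 == 0 -> this step is logged
assert not logger.is_log_step("loss", 25)  # 25 % 10 != 0 -> this step is skipped

# Mapping: per-metric frequencies, as before this change.
logger = StdoutLogger(freq={"loss": 10, "lr": 100})
for step in range(1000):
    loss = 1.0 / (step + 1)          # dummy value standing in for a real metric
    logger.log("loss", loss, step)   # emitted only when step % 10 == 0
    logger.log("lr", 3e-4, step)     # emitted only when step % 100 == 0
```

With the Mapping form, any metric name missing from `freq` would still raise a `KeyError` inside `is_log_step`, exactly as before; the int form sidesteps that by design.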