diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml
index 37c24b69d..567e4bad1 100644
--- a/apps/sft/llama3_8b.yaml
+++ b/apps/sft/llama3_8b.yaml
@@ -2,10 +2,10 @@
 # profiling:
 #   enable_profiling: false
 
-# metrics:
-#   log_freq: 10
-#   enable_tensorboard: true
-#   save_tb_folder: "tb"
+metrics:
+  logger: tensorboard
+  freq:
+    loss: 10
 
 # TODO: required by torchtitan
 # https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265
diff --git a/apps/sft/main.py b/apps/sft/main.py
index b5ae6fc16..2933ab684 100644
--- a/apps/sft/main.py
+++ b/apps/sft/main.py
@@ -19,6 +19,7 @@
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
 from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX
+from forge.util import get_metric_logger
 from omegaconf import DictConfig, OmegaConf
 
 from torch import nn
@@ -63,7 +64,7 @@ def __init__(self, job_config: ForgeJobConfig):
         self.num_training_steps = job_config.training.steps
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         super().__init__(job_config)
-        self.metric_logger = None  # TODO: fix this
+        self.metric_logger = get_metric_logger(**job_config.metrics)
 
     def setup(self):
         self.train_dataloader = self.setup_data(
@@ -203,6 +204,7 @@ def train_step(self, batch) -> None:
         loss = self.forward_backward(batch, labels)
         self.pbar.update(1)
         self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
+        self.metric_logger.log("loss", loss.item(), self.current_step)
 
         self.optimizers.step()
         self.lr_schedulers.step()
diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py
index b485fc791..76843bce2 100644
--- a/src/forge/interfaces.py
+++ b/src/forge/interfaces.py
@@ -5,12 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 from abc import ABC, abstractmethod
-from typing import Any
-
-from forge.types import Action, Message, Observation, State
+from typing import Any, Mapping
 
 from monarch.actor import Actor, endpoint
 
+from forge.types import Action, Message, Observation, Scalar, State
+
 
 class Transform(ABC):
     """Abstract base class for observation transforms.
@@ -152,6 +152,51 @@ def tokenize_messages(
         pass
 
 
+class MetricLogger(ABC):
+    """Abstract metric logger."""
+
+    @abstractmethod
+    def is_log_step(self, name: str, step: int) -> bool:
+        """Returns true if the current step is a logging step.
+
+        Args:
+            name (str): metric name (for checking the freq for this metric)
+            step (int): current step
+        """
+        pass
+
+    @abstractmethod
+    def log(self, name: str, data: Scalar, step: int) -> None:
+        """Log scalar data if this is a logging step.
+
+        Args:
+            name (str): tag name used to group scalars
+            data (Scalar): scalar data to log
+            step (int): step value to record
+        """
+        pass
+
+    @abstractmethod
+    def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
+        """Log multiple scalar values if this is a logging step.
+
+        Args:
+            metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value
+            step (int): step value to record
+        """
+        pass
+
+    def __del__(self) -> None:
+        self.close()
+
+    def close(self) -> None:
+        """
+        Close log resource, flushing if necessary.
+        This will automatically be called via __del__ when the instance goes out of scope.
+        Logs should not be written after `close` is called.
+        """
+
+
 class Reward(ABC):
     """Abstract base class for reward models."""
 
diff --git a/src/forge/types.py b/src/forge/types.py
index 931b56c45..7d23b7838 100644
--- a/src/forge/types.py
+++ b/src/forge/types.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass, field
-from typing import Any, Literal, TypedDict
+from typing import Any, Literal, TypedDict, Union
 
 
 class Message(TypedDict):
@@ -130,3 +130,6 @@ def to_process_config(self) -> ProcessConfig:
             identity=self.identity,
             image=self.image,
         )
+
+
+Scalar = Union[int, float]
diff --git a/src/forge/util/__init__.py b/src/forge/util/__init__.py
new file mode 100644
index 000000000..5fb03b0f9
--- /dev/null
+++ b/src/forge/util/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from .distributed import get_world_size_and_rank
+from .logging import get_logger, log_once, log_rank_zero
+from .metric_logging import get_metric_logger
+
+__all__ = [
+    "get_world_size_and_rank",
+    "get_logger",
+    "log_once",
+    "log_rank_zero",
+    "get_metric_logger",
+]
diff --git a/src/forge/util/distributed.py b/src/forge/util/distributed.py
new file mode 100644
index 000000000..b32be7291
--- /dev/null
+++ b/src/forge/util/distributed.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+def get_world_size_and_rank() -> tuple[int, int]:
+    """Function that gets the current world size (aka total number
+    of ranks) and rank number of the current process in the default process group.
+
+    Returns:
+        tuple[int, int]: world size, rank
+    """
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        return torch.distributed.get_world_size(), torch.distributed.get_rank()
+    else:
+        return 1, 0
diff --git a/src/forge/util/logging.py b/src/forge/util/logging.py
new file mode 100644
index 000000000..e53218ccd
--- /dev/null
+++ b/src/forge/util/logging.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from functools import lru_cache
+from typing import Optional, TypeVar
+
+from torch import distributed as dist
+
+T = TypeVar("T", bound=type)
+
+
+def get_logger(level: Optional[str] = None) -> logging.Logger:
+    """
+    Get a logger with a stream handler.
+
+    Args:
+        level (Optional[str]): The logging level. See https://docs.python.org/3/library/logging.html#levels for list of levels.
+
+    Example:
+        >>> logger = get_logger("INFO")
+        >>> logger.info("Hello world!")
+        INFO:forge.util.logging:Hello world!
+
+    Returns:
+        logging.Logger: The logger.
+    """
+    logger = logging.getLogger(__name__)
+    if not logger.hasHandlers():
+        logger.addHandler(logging.StreamHandler())
+    if level is not None:
+        level = getattr(logging, level.upper())
+        logger.setLevel(level)
+    return logger
+
+
+def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
+    """
+    Logs a message only on rank zero.
+
+    Args:
+        logger (logging.Logger): The logger.
+        msg (str): The message to log.
+        level (int): The logging level.
+            See https://docs.python.org/3/library/logging.html#levels for values.
+            Defaults to ``logging.INFO``.
+    """
+    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+    if rank != 0:
+        return
+    logger.log(level, msg, stacklevel=2)
+
+
+@lru_cache(None)
+def log_once(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
+    """
+    Logs a message only once. LRU cache is used to ensure a specific message is
+    logged only once, similar to how :func:`~warnings.warn` works when the ``once``
+    rule is set via command-line or environment variable.
+
+    Args:
+        logger (logging.Logger): The logger.
+        msg (str): The message to log.
+        level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values.
+            Defaults to ``logging.INFO``.
+    """
+    log_rank_zero(logger=logger, msg=msg, level=level)
diff --git a/src/forge/util/metric_logging.py b/src/forge/util/metric_logging.py
new file mode 100644
index 000000000..75790c813
--- /dev/null
+++ b/src/forge/util/metric_logging.py
@@ -0,0 +1,285 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import sys
+import time
+from typing import Mapping, Optional
+
+from forge.interfaces import MetricLogger
+from forge.types import Scalar
+from forge.util.distributed import get_world_size_and_rank
+
+
+def get_metric_logger(logger: str = "stdout", **log_config) -> MetricLogger:
+    """Construct a metric logger by name ("stdout", "tensorboard" or "wandb"),
+    forwarding the remaining config (e.g. ``freq``) to the logger's constructor."""
+    return METRIC_LOGGER_STR_TO_CLS[logger](**log_config)
+
+
+class StdoutLogger(MetricLogger):
+    """Logger to standard output.
+
+    Args:
+        freq (Mapping[str, int]):
+            calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
+    """
+
+    def __init__(self, freq: Mapping[str, int]):
+        self._freq = freq
+
+    def is_log_step(self, name: str, step: int) -> bool:
+        """Returns true if the current step is a logging step.
+
+        Args:
+            name (str): metric name (for checking the freq for this metric)
+            step (int): current step
+        """
+        return step % self._freq[name] == 0
+
+    def log(self, name: str, data: Scalar, step: int) -> None:
+        """Log the metric if it is a logging step.
+
+        Args:
+            name (str): metric name
+            data (Scalar): metric value
+            step (int): current step
+        """
+        if not self.is_log_step(name, step):
+            return
+        print(f"Step {step} | {name}:{data}")
+
+    def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
+        """Log the metrics for which this is currently a logging step.
+
+        Args:
+            metrics (Mapping[str, Scalar]): dict of metric names and values
+            step (int): current step
+        """
+        log_step_metrics = {
+            name: value
+            for name, value in metrics.items()
+            if self.is_log_step(name, step)
+        }
+        if not log_step_metrics:
+            return
+
+        print(f"Step {step} | ", end="")
+        for name, data in log_step_metrics.items():
+            print(f"{name}:{data} ", end="")
+        print("\n", end="")
+
+    def close(self) -> None:
+        sys.stdout.flush()
+
+
+class TensorBoardLogger(MetricLogger):
+    """Logger for use with PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).
+
+    Args:
+        freq (Mapping[str, int]):
+            calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
+        log_dir (str): TensorBoard log directory
+        organize_logs (bool): If `True`, this class will create a subdirectory within `log_dir` for the current
+            run. Having sub-directories allows you to compare logs across runs.
+            When TensorBoard is passed a logdir at startup, it recursively walks the directory tree rooted at logdir looking for
+            subdirectories that contain tfevents data. Every time it encounters such a subdirectory,
+            it loads it as a new run, and the frontend will organize the data accordingly.
+            Recommended value is `True`. Run `tensorboard --logdir my_log_dir` to view the logs.
+        **kwargs: additional arguments
+
+    Example:
+        >>> from forge.util.metric_logging import TensorBoardLogger
+        >>> logger = TensorBoardLogger(freq={"my_metric": 10}, log_dir="my_log_dir")
+        >>> logger.log("my_metric", 1.0, 10)
+        >>> logger.log_dict({"my_metric": 1.0}, 10)
+        >>> logger.close()
+
+    Note:
+        This utility requires the tensorboard package to be installed.
+        You can install it with `pip install tensorboard`.
+        In order to view TensorBoard logs, you need to run `tensorboard --logdir my_log_dir` in your terminal.
+    """
+
+    def __init__(
+        self,
+        freq: Mapping[str, int],
+        log_dir: str = "metrics_log",
+        organize_logs: bool = True,
+        **kwargs,
+    ):
+        from torch.utils.tensorboard import SummaryWriter
+
+        self._freq = freq
+        self._writer: Optional[SummaryWriter] = None
+        _, rank = get_world_size_and_rank()
+
+        # In case organize_logs is `True`, update log_dir to include a subdirectory for the
+        # current run
+        self.log_dir = (
+            os.path.join(log_dir, f"run_{rank}_{time.time()}")
+            if organize_logs
+            else log_dir
+        )
+
+        # Initialize the log writer only if we're on rank 0.
+        if rank == 0:
+            self._writer = SummaryWriter(log_dir=self.log_dir)
+
+    def is_log_step(self, name: str, step: int) -> bool:
+        """Returns true if the current step is a logging step.
+
+        Args:
+            name (str): metric name (for checking the freq for this metric)
+            step (int): current step
+        """
+        return step % self._freq[name] == 0
+
+    def log(self, name: str, data: Scalar, step: int) -> None:
+        """Log the metric if it is a logging step.
+
+        Args:
+            name (str): metric name
+            data (Scalar): metric value
+            step (int): current step
+        """
+        if self._writer and self.is_log_step(name, step):
+            self._writer.add_scalar(name, data, global_step=step, new_style=True)
+
+    def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
+        """Log the metrics for which this is currently a logging step.
+
+        Args:
+            metrics (Mapping[str, Scalar]): dict of metric names and values
+            step (int): current step
+        """
+        for name, data in metrics.items():
+            if self.is_log_step(name, step):
+                self.log(name, data, step)
+
+    def close(self) -> None:
+        if self._writer:
+            self._writer.close()
+            self._writer = None
+
+
+class WandBLogger(MetricLogger):
+    """Logger for use with the Weights and Biases application (https://wandb.ai/).
+    For more information about arguments expected by WandB, see https://docs.wandb.ai/ref/python/init.
+
+    Args:
+        freq (Mapping[str, int]):
+            calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
+        log_dir (str): WandB log directory. Default is `metrics_log`.
+        project (str): WandB project name.
+        entity (Optional[str]): WandB entity name. If you don't specify an entity,
+            the run will be sent to your default entity, which is usually your username.
+        group (Optional[str]): WandB group name for grouping runs together. If you don't
+            specify a group, the run will be logged as an individual experiment.
+        **kwargs: additional arguments to pass to wandb.init
+
+    Example:
+        >>> from forge.util.metric_logging import WandBLogger
+        >>> logger = WandBLogger(freq={"my_metric": 10}, log_dir="wandb", project="my_project")
+        >>> logger.log("my_metric", 1.0, 10)
+        >>> logger.log_dict({"my_metric": 1.0}, 10)
+        >>> logger.close()
+
+    Raises:
+        ImportError: If ``wandb`` package is not installed.
+
+    Note:
+        This logger requires the wandb package to be installed.
+        You can install it with `pip install wandb`.
+        In order to use the logger, you need to login to your WandB account.
+        You can do this by running `wandb login` in your terminal.
+    """
+
+    def __init__(
+        self,
+        freq: Mapping[str, int],
+        project: str,
+        log_dir: str = "metrics_log",
+        entity: Optional[str] = None,
+        group: Optional[str] = None,
+        **kwargs,
+    ):
+        self._freq = freq
+
+        try:
+            import wandb
+        except ImportError as e:
+            raise ImportError(
+                "``wandb`` package not found. Please install wandb using `pip install wandb` to use WandBLogger."
+            ) from e
+        self._wandb = wandb
+
+        if not os.path.exists(log_dir):
+            os.makedirs(log_dir)
+
+        _, rank = get_world_size_and_rank()
+        if self._wandb.run is None and rank == 0:
+            # we check if wandb.init got called externally
+            self._wandb.init(
+                project=project,
+                entity=entity,
+                group=group,
+                dir=log_dir,
+                **kwargs,
+            )
+
+        if self._wandb.run:
+            # define default x-axis (for latest wandb versions)
+            if getattr(self._wandb, "define_metric", None):
+                self._wandb.define_metric("step")
+                self._wandb.define_metric("*", step_metric="step", step_sync=True)
+
+    def is_log_step(self, name: str, step: int) -> bool:
+        """Returns true if the current step is a logging step.
+
+        Args:
+            name (str): metric name (for checking the freq for this metric)
+            step (int): current step
+        """
+        return step % self._freq[name] == 0
+
+    def log(self, name: str, data: Scalar, step: int) -> None:
+        """Log the metric if it is a logging step.
+
+        Args:
+            name (str): metric name
+            data (Scalar): metric value
+            step (int): current step
+        """
+        if self._wandb.run and self.is_log_step(name, step):
+            self._wandb.log({name: data, "step": step})
+
+    def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
+        """Log the metrics for which this is currently a logging step.
+
+        Args:
+            metrics (Mapping[str, Scalar]): dict of metric names and values
+            step (int): current step
+        """
+        log_step_metrics = {
+            name: value
+            for name, value in metrics.items()
+            if self.is_log_step(name, step)
+        }
+        if not log_step_metrics:
+            return
+
+        if self._wandb.run:
+            self._wandb.log({**log_step_metrics, "step": step})
+
+    def close(self) -> None:
+        if hasattr(self, "_wandb") and self._wandb.run:
+            self._wandb.finish()
+
+
+# TODO: replace with direct instantiation via a path to the class in the config
+METRIC_LOGGER_STR_TO_CLS = {
+    "stdout": StdoutLogger,
+    "tensorboard": TensorBoardLogger,
+    "wandb": WandBLogger,
+}
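
A minimal usage sketch (illustrative only, not part of the diff above): it assumes the modules added here are importable as `forge` and mirrors the shape of the `metrics` block in apps/sft/llama3_8b.yaml, but uses the `stdout` logger so no extra dependencies are required.

# Wire a config shaped like the new `metrics` section into get_metric_logger
# and log a scalar each step. The yaml in this diff uses `logger: tensorboard`,
# which additionally requires the tensorboard package.
from forge.util import get_metric_logger

# Equivalent of `metrics: {logger: stdout, freq: {loss: 10}}`.
metric_logger = get_metric_logger(logger="stdout", freq={"loss": 10})

for step in range(1, 101):
    loss = 1.0 / step  # placeholder value standing in for the real training loss
    metric_logger.log("loss", loss, step)  # StdoutLogger prints only when step % 10 == 0

metric_logger.close()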