8 changes: 4 additions & 4 deletions apps/sft/llama3_8b.yaml
@@ -2,10 +2,10 @@
# profiling:
# enable_profiling: false

# metrics:
# log_freq: 10
# enable_tensorboard: true
# save_tb_folder: "tb"
metrics:
logger: tensorboard
freq:
loss: 10

# TODO: required by torchtitan
# https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265
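The commented-out torchtitan-style metrics block is replaced by a `metrics` section that is unpacked straight into the new logger factory in apps/sft/main.py below. A minimal sketch of what that section resolves to, assuming OmegaConf loads the YAML (the `cfg` name here is hypothetical; the real call site is `get_metric_logger(**job_config.metrics)`):

from omegaconf import OmegaConf

# Hypothetical stand-in for the resolved apps/sft/llama3_8b.yaml config.
cfg = OmegaConf.create({"metrics": {"logger": "tensorboard", "freq": {"loss": 10}}})

# main.py unpacks this section into the factory as keyword arguments,
# i.e. get_metric_logger(logger="tensorboard", freq={"loss": 10}).
print(OmegaConf.to_container(cfg.metrics))  # {'logger': 'tensorboard', 'freq': {'loss': 10}}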
4 changes: 3 additions & 1 deletion apps/sft/main.py
@@ -19,6 +19,7 @@
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX
from forge.util import get_metric_logger

from omegaconf import DictConfig, OmegaConf
from torch import nn
@@ -63,7 +64,7 @@ def __init__(self, job_config: ForgeJobConfig):
self.num_training_steps = job_config.training.steps
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
super().__init__(job_config)
self.metric_logger = None # TODO: fix this
self.metric_logger = get_metric_logger(**job_config.metrics)

def setup(self):
self.train_dataloader = self.setup_data(
@@ -203,6 +204,7 @@ def train_step(self, batch) -> None:
loss = self.forward_backward(batch, labels)
self.pbar.update(1)
self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
self.metric_logger.log("loss", loss.item(), self.current_step)

self.optimizers.step()
self.lr_schedulers.step()
51 changes: 48 additions & 3 deletions src/forge/interfaces.py
@@ -5,12 +5,12 @@
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any

from forge.types import Action, Message, Observation, State
from typing import Any, Mapping

from monarch.actor import Actor, endpoint

from forge.types import Action, Message, Observation, Scalar, State


class Transform(ABC):
"""Abstract base class for observation transforms.
@@ -152,6 +152,51 @@ def tokenize_messages(
pass


class MetricLogger(ABC):
"""Abstract metric logger."""

@abstractmethod
def is_log_step(self, name: str, step: int) -> bool:
"""Returns true if the current step is a logging step.

Args:
name (str): metric name (for checking the freq for this metric)
step (int): current step
"""
pass

@abstractmethod
def log(self, name: str, data: Scalar, step: int) -> None:
"""Log scalar data if this is a logging step.

Args:
name (str): tag name used to group scalars
data (Scalar): scalar data to log
step (int): step value to record
"""
pass

@abstractmethod
def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
"""Log multiple scalar values if this is a logging step.

Args:
metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value
step (int): step value to record
"""
pass

def __del__(self) -> None:
self.close()

def close(self) -> None:
"""
Close log resource, flushing if necessary.
This will automatically be called via __del__ when the instance goes out of scope.
Logs should not be written after `close` is called.
"""


class Reward(ABC):
"""Abstract base class for reward models."""

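To make the contract concrete, here is a minimal sketch of a subclass satisfying the new MetricLogger ABC. It is purely illustrative (a stdout logger, not the implementation that get_metric_logger returns), and it assumes the forge.interfaces and forge.types definitions from this diff:

from typing import Mapping, Optional

from forge.interfaces import MetricLogger
from forge.types import Scalar


class StdoutMetricLogger(MetricLogger):
    """Illustrative example only: prints metrics instead of writing to TensorBoard."""

    def __init__(self, freq: Optional[Mapping[str, int]] = None, default_freq: int = 1) -> None:
        # Per-metric logging frequency, mirroring the `freq: {loss: 10}` YAML section above.
        self._freq = dict(freq or {})
        self._default_freq = default_freq

    def is_log_step(self, name: str, step: int) -> bool:
        return step % self._freq.get(name, self._default_freq) == 0

    def log(self, name: str, data: Scalar, step: int) -> None:
        if self.is_log_step(name, step):
            print(f"step {step} | {name} = {data}")

    def log_dict(self, metrics: Mapping[str, Scalar], step: int) -> None:
        for name, data in metrics.items():
            self.log(name, data, step)

    def close(self) -> None:
        pass  # nothing to flush for stdout


# Usage mirroring the train_step change in apps/sft/main.py:
#   logger = StdoutMetricLogger(freq={"loss": 10})
#   logger.log("loss", 0.42, step=10)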
5 changes: 4 additions & 1 deletion src/forge/types.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Any, Literal, TypedDict
from typing import Any, Literal, TypedDict, Union


class Message(TypedDict):
@@ -130,3 +130,6 @@ def to_process_config(self) -> ProcessConfig:
identity=self.identity,
image=self.image,
)


Scalar = Union[int, float]
16 changes: 16 additions & 0 deletions src/forge/util/__init__.py
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .distributed import get_world_size_and_rank
from .logging import get_logger, log_once, log_rank_zero
from .metric_logging import get_metric_logger

__all__ = [
"get_world_size_and_rank",
"get_logger",
"log_once",
"log_rank_zero",
"get_metric_logger",
]
20 changes: 20 additions & 0 deletions src/forge/util/distributed.py
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


def get_world_size_and_rank() -> tuple[int, int]:
"""Function that gets the current world size (aka total number
of ranks) and rank number of the current process in the default process group.

Returns:
tuple[int, int]: world size, rank
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_world_size(), torch.distributed.get_rank()
else:
return 1, 0
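A short usage sketch of the new helper (behavior as described in its docstring; the torchrun launch is just an example setting):

import torch.distributed as dist

from forge.util import get_world_size_and_rank

# With no process group initialized, the helper falls back to a single process.
world_size, rank = get_world_size_and_rank()  # -> (1, 0) outside torch.distributed

# Under an initialized process group (e.g. a torchrun launch that called
# dist.init_process_group), it defers to torch.distributed, so rank 0 can
# gate work such as metric logging or progress bars.
if dist.is_available() and dist.is_initialized():
    world_size, rank = get_world_size_and_rank()
    if rank == 0:
        print(f"world_size={world_size}")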
69 changes: 69 additions & 0 deletions src/forge/util/logging.py
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from functools import lru_cache
from typing import Optional, TypeVar

from torch import distributed as dist

T = TypeVar("T", bound=type)


def get_logger(level: Optional[str] = None) -> logging.Logger:
"""
Get a logger with a stream handler.
Args:
level (Optional[str]): The logging level. See https://docs.python.org/3/library/logging.html#levels for list of levels.
Example:
>>> logger = get_logger("INFO")
>>> logger.info("Hello world!")
INFO:torchtune.utils._logging:Hello world!
[Contributor review comment] Suggested change to the example output line:
-        INFO:torchtune.utils._logging:Hello world!
+        INFO:forge.util.logging:Hello world!

Returns:
logging.Logger: The logger.
"""
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
logger.addHandler(logging.StreamHandler())
if level is not None:
level = getattr(logging, level.upper())
logger.setLevel(level)
return logger


def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
"""
Logs a message only on rank zero.
Args:
logger (logging.Logger): The logger.
msg (str): The warning message.
level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values.
Defaults to ``logging.INFO``.
"""
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
if rank != 0:
return
[Member review comment] Shouldn't the log rank be configurable, like torchtitan does? But I guess we can always improve later.

logger.log(level, msg, stacklevel=2)


@lru_cache(None)
def log_once(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
"""
Logs a message only once. LRU cache is used to ensure a specific message is
logged only once, similar to how :func:`~warnings.warn` works when the ``once``
rule is set via command-line or environment variable.
Args:
logger (logging.Logger): The logger.
msg (str): The warning message.
level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values.
Defaults to ``logging.INFO``.
"""
log_rank_zero(logger=logger, msg=msg, level=level)
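Putting the three helpers together, a usage sketch based on the signatures above (the log messages are placeholders):

from forge.util import get_logger, log_once, log_rank_zero

logger = get_logger("INFO")
logger.info("starting SFT run")                      # ordinary logging on every rank

log_rank_zero(logger, "config loaded")               # emitted only on rank 0
log_once(logger, "packing disabled; using padding")  # emitted at most once per (logger, msg, level)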