Merged
Changes from 7 commits
8 changes: 4 additions & 4 deletions apps/sft/llama3_8b.yaml
@@ -2,10 +2,10 @@
# profiling:
# enable_profiling: false

# metrics:
# log_freq: 10
# enable_tensorboard: true
# save_tb_folder: "tb"
metrics:
logger: "stdout"
log_freq:
loss: 10

# TODO: required by torchtitan
# https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265
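For context, a rough sketch (not part of this PR) of how a metrics block like the one above reaches the logger: apps/sft/main.py calls get_metric_logger(**job_config.metrics), so logger and log_freq arrive as keyword arguments. The factory itself lives in src/forge/util/metric_logging.py, which this diff does not show, so the snippet below is illustrative only.

# Illustrative only: get_metric_logger's implementation is not part of this diff;
# it is assumed to return a MetricLogger configured from these kwargs.
from forge.util import get_metric_logger

metrics_cfg = {"logger": "stdout", "log_freq": {"loss": 10}}  # mirrors the YAML above
metric_logger = get_metric_logger(**metrics_cfg)
metric_logger.log("loss", 0.42, step=10)  # written: 10 % 10 == 0
metric_logger.log("loss", 0.41, step=11)  # skipped: 11 % 10 != 0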
4 changes: 3 additions & 1 deletion apps/sft/main.py
@@ -18,6 +18,7 @@
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.util import get_metric_logger

from omegaconf import DictConfig, OmegaConf
from torch import nn
@@ -60,7 +61,7 @@ def __init__(self, job_config: ForgeJobConfig):
self.num_training_steps = job_config.training.steps
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
super().__init__(job_config)
self.metric_logger = None # TODO: fix this
self.metric_logger = get_metric_logger(**job_config.metrics)

def setup(self):
self.train_dataloader = self.setup_data()
@@ -185,6 +186,7 @@ def train_step(self, batch) -> None:
loss = self.forward_backward(batch, labels)
self.pbar.update(1)
self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
self.metric_logger.log("loss", loss, self.current_step)

self.optimizers.step()
self.lr_schedulers.step()
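Worth noting at this call site: the MetricLogger ABC added in src/forge/interfaces.py also provides set_step, so a trainer can record the step once per iteration instead of passing it to every log call. A minimal sketch of that alternative (not what this PR does), assuming the stdout logger from the YAML config:

# Sketch only: record the step once, then log without repeating it.
from forge.util import get_metric_logger

metric_logger = get_metric_logger(logger="stdout", log_freq={"loss": 10})
for step in range(1, 31):
    metric_logger.set_step(step)
    metric_logger.log("loss", 0.5)  # emitted only when step % 10 == 0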
8 changes: 4 additions & 4 deletions apps/sft_v2/llama3_8b.yaml
@@ -2,10 +2,10 @@
# profiling:
# enable_profiling: false

# metrics:
# log_freq: 10
# enable_tensorboard: true
# save_tb_folder: "tb"
metrics:
logger: "stdout"
log_freq:
loss: 10

# TODO: required by torchtitan
# https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265
5 changes: 4 additions & 1 deletion apps/sft_v2/main.py
@@ -29,6 +29,7 @@
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.util import get_metric_logger

from monarch.actor import current_rank, current_size, endpoint
from omegaconf import DictConfig, OmegaConf
@@ -74,7 +75,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
def __init__(self, job_config: ForgeJobConfig):
self.current_step = 0
self.num_training_steps = job_config.training.steps
self.metric_logger = None # TODO: fix this
self.metric_logger = get_metric_logger(**job_config.metrics)
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
self._rank = current_rank().rank
self._size = math.prod(current_size().values())
@@ -238,6 +239,8 @@ def train_step(self, batch) -> None:
logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
# self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
# self.pbar.update(1)
self.metric_logger.log("loss", loss, self.current_step)

self.optimizers.step()
self.lr_schedulers.step()

111 changes: 109 additions & 2 deletions src/forge/interfaces.py
@@ -5,11 +5,11 @@
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Mapping, Optional

from monarch.actor import Actor, endpoint

from forge.types import Action, Message, Observation, State
from forge.types import Action, Message, Observation, Scalar, State


class Transform(ABC):
@@ -150,3 +150,110 @@ def tokenize_messages(
tuple[list[int], list[bool]]: The list of token ids and the list of masks.
"""
pass


class MetricLogger(ABC):
"""Abstract metric logger.

Args:
log_freq (Mapping[str, int]):
calls to `log` and `log_dict` will be ignored if `step % log_freq[metric_name] != 0`
"""

def __init__(self, log_freq: Mapping[str, int]):
self._log_freq = log_freq
self._step = None

def set_step(self, step: int) -> None:
"""Subsequent log calls will use this step number by default if not provided to the log call."""
self._step = step

def is_log_step(self, name: str, step: Optional[int] = None):
"""Returns true if the current step is a logging step.

Args:
name (str): metric name (for checking the log freq for this metric)
step (int): current step. if not given, will use the one last provided via set_step()
"""
if step is None:
assert (
self._step is not None
), "`step` arg required if `set_step` has not been called."
step = self._step
return step % self._log_freq[name] == 0

def log(
self,
name: str,
data: Scalar,
step: Optional[int] = None,
) -> None:
"""Log scalar data if this is a logging step.

Args:
name (str): tag name used to group scalars
data (Scalar): scalar data to log
step (int): step value to record. if not given, will use the one last provided via set_step()
"""
if step is None:
assert (
self._step is not None
), "`step` arg required if `set_step` has not been called."
step = self._step
if step % self._log_freq[name] == 0:
self._log(name, data, step)

def log_dict(
self, metrics: Mapping[str, Scalar], step: Optional[int] = None
) -> None:
"""Log multiple scalar values if this is a logging step.

Args:
metrics (Mapping[str, Scalar]): dictionary of tag name and scalar value
step (int): step value to record. if not given, will use the one last provided via set_step()
"""
if step is None:
assert (
self._step is not None
), "`step` arg required if `set_step` has not been called."
step = self._step

log_step_metrics = {
name: value
for name, value in metrics.items()
if step % self._log_freq[name] == 0
}
if log_step_metrics:
self._log_dict(log_step_metrics, step)

@abstractmethod
def _log(self, name: str, data: Scalar, step: int) -> None:
"""Log scalar data.

Args:
name (str): tag name used to group scalars
data (Scalar): scalar data to log
step (int): step value to record
"""
pass

@abstractmethod
def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
"""Log multiple scalar values.

Args:
payload (Mapping[str, Scalar]): dictionary of tag name and scalar value
step (int): step value to record
"""
pass

def __del__(self) -> None:
self.close()

def close(self) -> None:
"""
Close log resource, flushing if necessary.
This will automatically be called via __del__ when the instance goes out of scope.
Logs should not be written after `close` is called.
"""
pass
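As a reference for implementers, a minimal concrete subclass of this ABC might look like the sketch below. The actual backends (for example the "stdout" logger selected in the YAML configs) live in src/forge/util/metric_logging.py, which this diff does not include, so the class name and output format here are assumptions.

# Hypothetical minimal subclass of the MetricLogger ABC defined above.
from typing import Mapping

from forge.interfaces import MetricLogger
from forge.types import Scalar


class StdoutMetricLogger(MetricLogger):
    def _log(self, name: str, data: Scalar, step: int) -> None:
        # Normalize tensors/ndarrays to plain Python floats before printing.
        value = data.item() if hasattr(data, "item") else float(data)
        print(f"step {step} | {name}: {value:.4f}")

    def _log_dict(self, payload: Mapping[str, Scalar], step: int) -> None:
        for name, value in payload.items():
            self._log(name, value, step)


logger = StdoutMetricLogger(log_freq={"loss": 10})
logger.log("loss", 0.37, step=20)  # printed: 20 % 10 == 0
logger.log("loss", 0.36, step=21)  # skipped: 21 % 10 != 0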
8 changes: 7 additions & 1 deletion src/forge/types.py
@@ -5,7 +5,10 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Any, Literal, TypedDict
from typing import Any, Literal, TypedDict, Union

from numpy import ndarray
from torch import Tensor


class Message(TypedDict):
@@ -98,3 +101,6 @@ class ProcessConfig:
oncall: str = "torchtune"
identity: str = "pytorch_distributed"
image: str = "forge_workspace:latest"


Scalar = Union[Tensor, ndarray, int, float]
15 changes: 15 additions & 0 deletions src/forge/util/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .logging import deprecated, get_logger, log_once, log_rank_zero
from .metric_logging import get_metric_logger

__all__ = [
"deprecated",
"get_logger",
"log_once",
"log_rank_zero",
"get_metric_logger",
]
147 changes: 147 additions & 0 deletions src/forge/util/logging.py
@@ -0,0 +1,147 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import inspect
import logging
import warnings
from functools import lru_cache, wraps
from typing import Callable, Optional, TypeVar

from torch import distributed as dist

T = TypeVar("T", bound=type)


def get_logger(level: Optional[str] = None) -> logging.Logger:
"""
Get a logger with a stream handler.
Args:
level (Optional[str]): The logging level. See https://docs.python.org/3/library/logging.html#levels for list of levels.
Example:
>>> logger = get_logger("INFO")
>>> logger.info("Hello world!")
INFO:torchtune.utils._logging:Hello world!
Contributor review comment (suggested change):
- INFO:torchtune.utils._logging:Hello world!
+ INFO:forge.util.logging:Hello world!
Returns:
logging.Logger: The logger.
"""
logger = logging.getLogger(__name__)
if not logger.hasHandlers():
logger.addHandler(logging.StreamHandler())
if level is not None:
level = getattr(logging, level.upper())
logger.setLevel(level)
return logger


def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
"""
Logs a message only on rank zero.
Args:
logger (logging.Logger): The logger.
msg (str): The warning message.
level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values.
Defaults to ``logging.INFO``.
"""
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
if rank != 0:
return
Member review comment: Shouldn't log rank be configurable like how titan did? But I guess we can always improve later.
logger.log(level, msg, stacklevel=2)


@lru_cache(None)
def log_once(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
"""
Logs a message only once. LRU cache is used to ensure a specific message is
logged only once, similar to how :func:`~warnings.warn` works when the ``once``
rule is set via command-line or environment variable.
Args:
logger (logging.Logger): The logger.
msg (str): The warning message.
level (int): The logging level. See https://docs.python.org/3/library/logging.html#levels for values.
Defaults to ``logging.INFO``.
"""
log_rank_zero(logger=logger, msg=msg, level=level)


def deprecated(msg: str = "") -> Callable[[T], T]:
"""
Decorator to mark an object as deprecated and print additional message.
Args:
msg (str): additional information to print after warning.
Returns:
Callable[[T], T]: the decorated object.
"""

@lru_cache(maxsize=1)
def warn(obj):
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
if rank != 0:
return
warnings.warn(
f"{obj.__name__} is deprecated and will be removed in future versions. "
+ msg,
category=FutureWarning,
stacklevel=3,
)

def decorator(obj):
@wraps(obj)
def wrapper(*args, **kwargs):
warn(obj)
return obj(*args, **kwargs)

return wrapper

return decorator


def deprecate_parameter(param_name: str, msg: str = "") -> Callable[[T], T]:
"""
Decorator to mark a parameter as deprecated and print additional message.
Args:
param_name (str): The name of the parameter.
msg (str): additional information to print after warning.
Returns:
Callable[[T], T]: the decorated object.
"""

@lru_cache(maxsize=1)
def warn(obj):
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
if rank != 0:
return
warnings.warn(
f"{param_name} is deprecated for {obj.__name__} and will be removed in future versions. "
+ msg,
category=FutureWarning,
stacklevel=3,
)

def decorator(obj):
sig = inspect.signature(obj)

@wraps(obj)
def wrapper(*args, **kwargs):
# Check positional and kwargs
bound_args = sig.bind_partial(*args, **kwargs)
all_args = {**bound_args.arguments}
all_args.update(kwargs)
if param_name in all_args:
warn(obj)
return obj(*args, **kwargs)

return wrapper

return decorator
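For illustration, a short usage sketch of the helpers above; the function names, parameter, and messages are hypothetical, not taken from this PR. Note that deprecate_parameter is not re-exported in src/forge/util/__init__.py, so it is imported from the module directly.

# Hypothetical usage of the utilities defined above.
from forge.util import get_logger, log_once
from forge.util.logging import deprecate_parameter, deprecated

logger = get_logger("INFO")
log_once(logger, "Dataset packed and cached.")  # logged once, on rank zero only


@deprecated(msg="Use load_checkpoint_v2 instead.")
def load_checkpoint(path: str) -> None:
    ...


@deprecate_parameter("legacy_format", msg="Pass format_version instead.")
def save_checkpoint(path: str, legacy_format: bool = False) -> None:
    ...


load_checkpoint("/tmp/ckpt.pt")  # FutureWarning emitted once, rank zero only
save_checkpoint("/tmp/ckpt.pt", legacy_format=True)  # warns because the deprecated kwarg is passed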