4 changes: 1 addition & 3 deletions src/forge/data/dataset_metrics/metric_transform.py
@@ -9,8 +9,6 @@
from enum import Enum
from typing import Any, Union

-from forge.interfaces import Transform
-

@dataclass(frozen=True)
class Metric:
@@ -35,7 +33,7 @@ class AggregationType(Enum):
MIN = "min"


-class MetricTransform(Transform, ABC):
+class MetricTransform(ABC):
"""Applied to each dataset sample to generate per-sample metrics for training tracking.

Creates Metric objects that are later aggregated by MetricsAggregator. This separation
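The docstring above describes a two-step pattern: a per-sample transform attaches small `Metric` records, and a separate `MetricsAggregator` reduces them later. Below is a self-contained sketch of that split; `Agg`, `SampleMetric`, `count_tokens`, and `aggregate` are hypothetical stand-ins for illustration, not the forge classes.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any


class Agg(Enum):  # stand-in for AggregationType (the real enum has more members, e.g. MIN)
    SUM = "sum"


@dataclass(frozen=True)
class SampleMetric:  # stand-in for Metric
    name: str
    value: float
    agg: Agg


def count_tokens(sample: dict[str, Any]) -> dict[str, Any]:
    """Per-sample step: attach a metric, defer aggregation."""
    metrics = list(sample.get("metrics", []))
    metrics.append(SampleMetric("tokens_seen", float(len(sample.get("tokens", []))), Agg.SUM))
    return {**sample, "metrics": metrics}


def aggregate(samples: list[dict[str, Any]]) -> dict[str, float]:
    """Later step: the reduction a MetricsAggregator-style component performs."""
    out: dict[str, float] = {}
    for sample in samples:
        for m in sample["metrics"]:
            if m.agg is Agg.SUM:
                out[m.name] = out.get(m.name, 0.0) + m.value
    return out


batch = [count_tokens({"tokens": [1, 2, 3]}), count_tokens({"tokens": [4, 5]})]
print(aggregate(batch))  # {'tokens_seen': 5.0}
```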
13 changes: 6 additions & 7 deletions src/forge/data/datasets/hf_dataset.py
@@ -18,7 +18,6 @@
Metric,
MetricTransform,
)
-from forge.interfaces import Transform

from .dataset import DatasetInfo, InfiniteTuneIterableDataset

@@ -37,10 +36,10 @@ class HfIterableDataset(InfiniteTuneIterableDataset):
- Returning an infinite iterator over the dataset

Args:
-message_transform (Transform | None): Transforms raw data into a `Message`.
-model_transform (Transform | None): Prepares messages for the model,
+message_transform (Callable | None): Transforms raw data into a `Message`.
+model_transform (Callable | None): Prepares messages for the model,
usually by tokenizing them.
-output_transform (Transform | None): Prepares tokenized inputs for the
+output_transform (Callable | None): Prepares tokenized inputs for the
recipe, often by manipulating labels (e.g., setting an ignore index).
This transform is recipe-dependent (e.g., SFT, DPO, etc.).
metric_transform (MetricTransform | None): Computes metrics from a
@@ -64,9 +63,9 @@ class HfIterableDataset(InfiniteTuneIterableDataset):
def __init__(
self,
*,
-message_transform: Transform | None = None,
-model_transform: Transform | None = None,
-output_transform: Transform | None = None,
+message_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+model_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+output_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
metric_transform: MetricTransform | None = None,
shuffle_buffer_size: int | None = 1000,
weight: float | None = 1.0,
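With the parameters now typed as `Callable[[dict[str, Any]], dict[str, Any]]`, any plain function, or any object defining `__call__` with that shape, can be passed; no `Transform` base class is involved. A minimal sketch of the duck-typed contract (the `question`/`answer` field names and the `TruncatingTokenizer` helper are made up for illustration):

```python
from typing import Any, Callable


def to_messages(sample: dict[str, Any]) -> dict[str, Any]:
    """A plain function is a valid message_transform."""
    sample["messages"] = [
        {"role": "user", "content": sample.get("question", "")},
        {"role": "assistant", "content": sample.get("answer", "")},
    ]
    return sample


class TruncatingTokenizer:
    """An object with __call__ works too, e.g. as model_transform."""

    def __init__(self, max_len: int = 512):
        self.max_len = max_len

    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
        text = " ".join(m["content"] for m in sample["messages"])
        sample["tokens"] = text.split()[: self.max_len]  # toy whitespace "tokenization"
        return sample


# Both satisfy the new annotations and could be handed to HfIterableDataset, e.g.
#   HfIterableDataset(message_transform=to_messages, model_transform=TruncatingTokenizer(), ...)
message_transform: Callable[[dict[str, Any]], dict[str, Any]] = to_messages
model_transform: Callable[[dict[str, Any]], dict[str, Any]] = TruncatingTokenizer()
```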
11 changes: 5 additions & 6 deletions src/forge/data/datasets/sft_dataset.py
@@ -11,12 +11,11 @@
from forge.data import CROSS_ENTROPY_IGNORE_IDX
from forge.data.dataset_metrics import DefaultTrainingMetricTransform
from forge.data.utils import mask_messages, TuneMessage
-from forge.interfaces import Transform

from .hf_dataset import HfIterableDataset


-class AlpacaToMessages(Transform):
+class AlpacaToMessages:
"""
Message transform class for Alpaca-style datasets with "instruction", "input", and "output"
(or equivalent fields specified in column_map) columns. User messages are formed from the
@@ -153,10 +152,10 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:


def sft_iterable_dataset(
-model_transform: Transform,
+model_transform: Callable[[dict[str, Any]], dict[str, Any]],
*,
weight: int = 1,
-message_transform: Transform,
+message_transform: Callable[[dict[str, Any]], dict[str, Any]],
shuffle_buffer_size: int | None = 1000,
seed: int = 42,
num_shards_per_rank: int = 64,
@@ -169,9 +168,9 @@ def sft_iterable_dataset(
Creates an SFT-ready iterable dataset with appropriate output transform.

Args:
-model_transform (Transform): Usually the tokenizer
+model_transform (Callable): Usually the tokenizer
weight (int): Weight of the dataset. Used for sampling when interleaving datasets.
-message_transform (Transform): Transform to convert raw data to messages
+message_transform (Callable): Transform to convert raw data to messages
shuffle_buffer_size (int | None): Buffer size for shuffling
seed (int): Random seed for shuffling
num_shards_per_rank (int): Target shards per worker
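`AlpacaToMessages` keeps its behaviour; only the `Transform` base class is dropped, so any plain class with a matching `__call__` can stand in for it when building an SFT dataset. A rough sketch is below; it mirrors the documented instruction/input/output handling but is a hypothetical stand-in, not the forge implementation, and the commented factory call elides arguments not shown in this diff.

```python
from typing import Any


class AlpacaStyleToMessages:
    """Hypothetical stand-in: build user/assistant messages from Alpaca-style rows."""

    def __init__(self, column_map: dict[str, str] | None = None):
        self.column_map = column_map or {}

    def _col(self, name: str) -> str:
        return self.column_map.get(name, name)

    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
        instruction = sample[self._col("instruction")]
        extra = sample.get(self._col("input")) or ""
        user_content = f"{instruction}\n\n{extra}".strip()
        sample["messages"] = [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": sample[self._col("output")]},
        ]
        return sample


# Passed to the factory like any other callable, e.g.:
#   ds = sft_iterable_dataset(
#       model_transform=tokenizer,              # usually the tokenizer
#       message_transform=AlpacaStyleToMessages(),
#       ...,                                    # dataset-selection args elided in this diff
#   )
```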
6 changes: 2 additions & 4 deletions src/forge/data/rewards.py
@@ -6,10 +6,8 @@

import re

-from forge.interfaces import Reward
-

-class MathReward(Reward):
+class MathReward:
"""Reward class for evaluating math correctness."""

def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1):
@@ -58,7 +56,7 @@ def _to_float(self, text: str) -> float | None:
return None


-class ThinkingReward(Reward):
+class ThinkingReward:
"""Reward class for evaluating use of <think> tags in reasoning."""

def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):
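With the `Reward` base class gone, a reward only needs to be a callable that returns a float, so functions and `__call__` objects can be mixed freely. The sketch below is self-contained and illustrative: `KeywordReward`, `LengthReward`, and the `score` helper are hypothetical, and the actual `MathReward`/`ThinkingReward` call signatures are not shown in this diff.

```python
from typing import Callable


class KeywordReward:
    """Toy reward: full credit if the keyword appears, none otherwise."""

    def __init__(self, keyword: str = "<think>", full_reward: float = 1.0):
        self.keyword = keyword
        self.full_reward = full_reward

    def __call__(self, completion: str) -> float:
        return self.full_reward if self.keyword in completion else 0.0


class LengthReward:
    """Toy reward: partial credit that grows with (capped) completion length."""

    def __init__(self, target_len: int = 200):
        self.target_len = target_len

    def __call__(self, completion: str) -> float:
        return min(len(completion), self.target_len) / self.target_len


def score(completion: str, rewards: list[Callable[[str], float]]) -> float:
    """Any mix of plain functions and __call__ objects works; no shared base class."""
    return sum(r(completion) for r in rewards) / len(rewards)


print(score("<think>plan</think> answer", [KeywordReward(), LengthReward()]))
```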
142 changes: 0 additions & 142 deletions src/forge/data/sharding.py

This file was deleted.

69 changes: 0 additions & 69 deletions src/forge/data_models/episode.py

This file was deleted.

19 changes: 0 additions & 19 deletions src/forge/data_models/scored_completion.py

This file was deleted.

34 changes: 1 addition & 33 deletions src/forge/interfaces.py
@@ -7,30 +7,7 @@
from abc import ABC, abstractmethod
from typing import Any, Mapping

-from forge.types import Message, Observation, Scalar
-
-
-class Transform(ABC):
-    """Abstract base class for observation transforms.
-
-    Transforms are first-class citizens that can modify observations,
-    typically to add rewards, compute metrics, or modify state.
-
-    They follow a functional interface where they take an observation
-    and return a (potentially modified) observation.
-    """
-
-    @abstractmethod
-    def __call__(self, observation: Observation) -> Observation:
-        """Transform an observation.
-
-        Args:
-            observation: The input observation to transform
-
-        Returns:
-            The transformed observation (may be the same instance if no changes)
-        """
-        pass
+from forge.types import Message, Scalar


class BaseTokenizer(ABC):
@@ -139,12 +116,3 @@ def close(self) -> None:
This will automatically be called via __del__ when the instance goes out of scope.
Logs should not be written after `close` is called.
"""


class Reward(ABC):
"""Abstract base class for reward models."""

@abstractmethod
def __call__(self, observation: Observation) -> float:
"""Compute a reward for an observation."""
pass
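The removed `Transform` and `Reward` ABCs enforced their contracts nominally. If a static contract is still wanted without reintroducing base classes, a structural `typing.Protocol` expresses the same call shapes while leaving implementers as plain functions or classes. A sketch under that assumption (`TransformLike`, `RewardLike`, and the dict-based observation type are illustrative, not part of forge):

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class TransformLike(Protocol):
    """Structural equivalent of the removed Transform ABC."""

    def __call__(self, observation: dict[str, Any]) -> dict[str, Any]: ...


@runtime_checkable
class RewardLike(Protocol):
    """Structural equivalent of the removed Reward ABC."""

    def __call__(self, observation: dict[str, Any]) -> float: ...


def add_step_count(observation: dict[str, Any]) -> dict[str, Any]:
    observation["steps"] = observation.get("steps", 0) + 1
    return observation


# Passes: the runtime check only verifies that __call__ exists;
# the full signature is verified statically by a type checker.
assert isinstance(add_step_count, TransformLike)
```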