Skip to content

Commit 2e4a8a5

Browse files
Commit message: Remove last bit
Parent: 1a92113 · Commit: 2e4a8a5

File tree

8 files changed: 15 additions and 283 deletions.

src/forge/data/dataset_metrics/metric_transform.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
from enum import Enum
1010
from typing import Any, Union
1111

12-
from forge.interfaces import Transform
13-
1412

1513
@dataclass(frozen=True)
1614
class Metric:
@@ -35,7 +33,7 @@ class AggregationType(Enum):
3533
MIN = "min"
3634

3735

38-
class MetricTransform(Transform, ABC):
36+
class MetricTransform(ABC):
3937
"""Applied to each dataset sample to generate per-sample metrics for training tracking.
4038
4139
Creates Metric objects that are later aggregated by MetricsAggregator. This separation

src/forge/data/datasets/hf_dataset.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
Metric,
1919
MetricTransform,
2020
)
21-
from forge.interfaces import Transform
2221

2322
from .dataset import DatasetInfo, InfiniteTuneIterableDataset
2423

@@ -37,10 +36,10 @@ class HfIterableDataset(InfiniteTuneIterableDataset):
3736
- Returning an infinite iterator over the dataset
3837
3938
Args:
40-
message_transform (Transform | None): Transforms raw data into a `Message`.
41-
model_transform (Transform | None): Prepares messages for the model,
39+
message_transform (Callable | None): Transforms raw data into a `Message`.
40+
model_transform (Callable | None): Prepares messages for the model,
4241
usually by tokenizing them.
43-
output_transform (Transform | None): Prepares tokenized inputs for the
42+
output_transform (Callable | None): Prepares tokenized inputs for the
4443
recipe, often by manipulating labels (e.g., setting an ignore index).
4544
This transform is recipe-dependent (e.g., SFT, DPO, etc.).
4645
metric_transform (MetricTransform | None): Computes metrics from a
@@ -64,9 +63,9 @@ class HfIterableDataset(InfiniteTuneIterableDataset):
6463
def __init__(
6564
self,
6665
*,
67-
message_transform: Transform | None = None,
68-
model_transform: Transform | None = None,
69-
output_transform: Transform | None = None,
66+
message_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
67+
model_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
68+
output_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
7069
metric_transform: MetricTransform | None = None,
7170
shuffle_buffer_size: int | None = 1000,
7271
weight: float | None = 1.0,

src/forge/data/datasets/sft_dataset.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
from forge.data import CROSS_ENTROPY_IGNORE_IDX
1212
from forge.data.dataset_metrics import DefaultTrainingMetricTransform
1313
from forge.data.utils import mask_messages, TuneMessage
14-
from forge.interfaces import Transform
1514

1615
from .hf_dataset import HfIterableDataset
1716

1817

19-
class AlpacaToMessages(Transform):
18+
class AlpacaToMessages:
2019
"""
2120
Message transform class for Alpaca-style datasets with "instruction", "input", and "output"
2221
(or equivalent fields specified in column_map) columns. User messages are formed from the
@@ -153,10 +152,10 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
153152

154153

155154
def sft_iterable_dataset(
156-
model_transform: Transform,
155+
model_transform: Callable[[dict[str, Any]], dict[str, Any]],
157156
*,
158157
weight: int = 1,
159-
message_transform: Transform,
158+
message_transform: Callable[[dict[str, Any]], dict[str, Any]],
160159
shuffle_buffer_size: int | None = 1000,
161160
seed: int = 42,
162161
num_shards_per_rank: int = 64,
@@ -169,9 +168,9 @@ def sft_iterable_dataset(
169168
Creates an SFT-ready iterable dataset with appropriate output transform.
170169
171170
Args:
172-
model_transform (Transform): Usually the tokenizer
171+
model_transform (Callable): Usually the tokenizer
173172
weight (int): Weight of the dataset. Used for sampling when interleaving datasets.
174-
message_transform (Transform): Transform to convert raw data to messages
173+
message_transform (Callable): Transform to convert raw data to messages
175174
shuffle_buffer_size (int | None): Buffer size for shuffling
176175
seed (int): Random seed for shuffling
177176
num_shards_per_rank (int): Target shards per worker

src/forge/data/rewards.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66

77
import re
88

9-
from forge.interfaces import Reward
109

11-
12-
class MathReward(Reward):
10+
class MathReward:
1311
"""Reward class for evaluating math correctness."""
1412

1513
def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1):
@@ -58,7 +56,7 @@ def _to_float(self, text: str) -> float | None:
5856
return None
5957

6058

61-
class ThinkingReward(Reward):
59+
class ThinkingReward:
6260
"""Reward class for evaluating use of <think> tags in reasoning."""
6361

6462
def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):

src/forge/data/sharding.py

Lines changed: 0 additions & 142 deletions
This file was deleted.

src/forge/data_models/episode.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

src/forge/data_models/scored_completion.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/forge/interfaces.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,7 @@
77
from abc import ABC, abstractmethod
88
from typing import Any, Mapping
99

10-
from forge.types import Message, Observation, Scalar
11-
12-
13-
class Transform(ABC):
14-
"""Abstract base class for observation transforms.
15-
16-
Transforms are first-class citizens that can modify observations,
17-
typically to add rewards, compute metrics, or modify state.
18-
19-
They follow a functional interface where they take an observation
20-
and return a (potentially modified) observation.
21-
"""
22-
23-
@abstractmethod
24-
def __call__(self, observation: Observation) -> Observation:
25-
"""Transform an observation.
26-
27-
Args:
28-
observation: The input observation to transform
29-
30-
Returns:
31-
The transformed observation (may be the same instance if no changes)
32-
"""
33-
pass
10+
from forge.types import Message, Scalar
3411

3512

3613
class BaseTokenizer(ABC):
@@ -139,12 +116,3 @@ def close(self) -> None:
139116
This will automatically be called via __del__ when the instance goes out of scope.
140117
Logs should not be written after `close` is called.
141118
"""
142-
143-
144-
class Reward(ABC):
145-
"""Abstract base class for reward models."""
146-
147-
@abstractmethod
148-
def __call__(self, observation: Observation) -> float:
149-
"""Compute a reward for an observation."""
150-
pass

0 commit comments

Comments
 (0)