diff --git a/src/forge/data/dataset_metrics/metric_transform.py b/src/forge/data/dataset_metrics/metric_transform.py
index 2898c8e43..a9af39ade 100644
--- a/src/forge/data/dataset_metrics/metric_transform.py
+++ b/src/forge/data/dataset_metrics/metric_transform.py
@@ -9,8 +9,6 @@
 from enum import Enum
 from typing import Any, Union
 
-from forge.interfaces import Transform
-
 
 @dataclass(frozen=True)
 class Metric:
@@ -35,7 +33,7 @@ class AggregationType(Enum):
     MIN = "min"
 
 
-class MetricTransform(Transform, ABC):
+class MetricTransform(ABC):
     """Applied to each dataset sample to generate per-sample metrics for training tracking.
 
     Creates Metric objects that are later aggregated by MetricsAggregator. This separation
diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py
index 6be68b41b..799dd89b9 100644
--- a/src/forge/data/datasets/hf_dataset.py
+++ b/src/forge/data/datasets/hf_dataset.py
@@ -18,7 +18,6 @@
     Metric,
     MetricTransform,
 )
-from forge.interfaces import Transform
 
 from .dataset import DatasetInfo, InfiniteTuneIterableDataset
 
@@ -37,10 +36,10 @@ class HfIterableDataset(InfiniteTuneIterableDataset):
         - Returning an infinite iterator over the dataset
 
     Args:
-        message_transform (Transform | None): Transforms raw data into a `Message`.
-        model_transform (Transform | None): Prepares messages for the model,
+        message_transform (Callable | None): Transforms raw data into a `Message`.
+        model_transform (Callable | None): Prepares messages for the model,
             usually by tokenizing them.
-        output_transform (Transform | None): Prepares tokenized inputs for the
+        output_transform (Callable | None): Prepares tokenized inputs for the
             recipe, often by manipulating labels (e.g., setting an ignore index).
             This transform is recipe-dependent (e.g., SFT, DPO, etc.).
         metric_transform (MetricTransform | None): Computes metrics from a
@@ -64,9 +63,9 @@ class HfIterableDataset(InfiniteTuneIterableDataset):
     def __init__(
         self,
         *,
-        message_transform: Transform | None = None,
-        model_transform: Transform | None = None,
-        output_transform: Transform | None = None,
+        message_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+        model_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
+        output_transform: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
         metric_transform: MetricTransform | None = None,
         shuffle_buffer_size: int | None = 1000,
         weight: float | None = 1.0,
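With the `Transform` ABC removed, the transform arguments above are typed as plain `Callable[[dict[str, Any]], dict[str, Any]]`. A minimal sketch of what now satisfies that hint (the `text`/`messages` field names are made up for illustration):

```python
from typing import Any, Callable


def to_messages(sample: dict[str, Any]) -> dict[str, Any]:
    # Hypothetical fields: wrap a raw "text" column into a single user message.
    return {"messages": [{"role": "user", "content": sample["text"]}]}


# Any plain callable matches the new message/model/output transform hints,
# so it can be passed straight to HfIterableDataset without subclassing.
message_transform: Callable[[dict[str, Any]], dict[str, Any]] = to_messages
print(message_transform({"text": "What is 2 + 2?"}))
```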
diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py
index b31d16fad..3a2574643 100644
--- a/src/forge/data/datasets/sft_dataset.py
+++ b/src/forge/data/datasets/sft_dataset.py
@@ -11,12 +11,11 @@
 from forge.data import CROSS_ENTROPY_IGNORE_IDX
 from forge.data.dataset_metrics import DefaultTrainingMetricTransform
 from forge.data.utils import mask_messages, TuneMessage
-from forge.interfaces import Transform
 
 from .hf_dataset import HfIterableDataset
 
 
-class AlpacaToMessages(Transform):
+class AlpacaToMessages:
     """
     Message transform class for Alpaca-style datasets with "instruction", "input", and "output"
     (or equivalent fields specified in column_map) columns. User messages are formed from the
@@ -153,10 +152,10 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
 
 
 def sft_iterable_dataset(
-    model_transform: Transform,
+    model_transform: Callable[[dict[str, Any]], dict[str, Any]],
     *,
     weight: int = 1,
-    message_transform: Transform,
+    message_transform: Callable[[dict[str, Any]], dict[str, Any]],
     shuffle_buffer_size: int | None = 1000,
     seed: int = 42,
     num_shards_per_rank: int = 64,
@@ -169,9 +168,9 @@ def sft_iterable_dataset(
     Creates an SFT-ready iterable dataset with appropriate output transform.
 
     Args:
-        model_transform (Transform): Usually the tokenizer
+        model_transform (Callable): Usually the tokenizer
         weight (int): Weight of the dataset. Used for sampling when interleaving datasets.
-        message_transform (Transform): Transform to convert raw data to messages
+        message_transform (Callable): Transform to convert raw data to messages
         shuffle_buffer_size (int | None): Buffer size for shuffling
         seed (int): Random seed for shuffling
         num_shards_per_rank (int): Target shards per worker
diff --git a/src/forge/data/rewards.py b/src/forge/data/rewards.py
index 29a86fc3a..23a0002df 100644
--- a/src/forge/data/rewards.py
+++ b/src/forge/data/rewards.py
@@ -6,10 +6,8 @@
 
 import re
 
-from forge.interfaces import Reward
-
 
-class MathReward(Reward):
+class MathReward:
    """Reward class for evaluating math correctness."""
 
     def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1):
@@ -58,7 +56,7 @@ def _to_float(self, text: str) -> float | None:
         return None
 
 
-class ThinkingReward(Reward):
+class ThinkingReward:
     """Reward class for evaluating use of <think> tags in reasoning."""
 
     def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):
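With the `Reward` ABC gone as well, `MathReward` and `ThinkingReward` are plain classes. Construction and calling are unchanged, but any `isinstance(..., Reward)` checks have to go. A small sketch using only the constructor arguments visible in this diff:

```python
from forge.data.rewards import MathReward, ThinkingReward

# Plain classes now: duck typing is the contract, there is no shared base class.
rewards = [
    MathReward(tolerance=1e-6, partial_credit=0.1),
    ThinkingReward(partial_reward=0.2, full_reward=1.0),
]
for reward in rewards:
    print(type(reward).__name__)
```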
diff --git a/src/forge/data/sharding.py b/src/forge/data/sharding.py
deleted file mode 100644
index 2027f8a43..000000000
--- a/src/forge/data/sharding.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-
-
-class VLLMSharding:
-    """
-    vLLM specific tensor parallel sharding strategy.
-    """
-
-    def __init__(self, tensor_parallel_size: int, rank: int):
-        self.tensor_parallel_size = tensor_parallel_size
-        self.rank = rank
-
-    def load_from_source_to_target(
-        self,
-        param_name: str,
-        source_tensor: torch.Tensor,
-        target_tensor: torch.Tensor,
-    ) -> None:
-        """
-        Copy a source tensor to a target tensor, handling sharding and replication.
-        """
-        # Determine sharding strategy for this parameter
-        shard_dim, is_sharded = self._get_tensor_parallel_sharding_strategy(param_name)
-
-        if not is_sharded:
-            # Parameter is replicated - shapes should match exactly
-            if source_tensor.shape != target_tensor.shape:
-                raise ValueError(
-                    f"Replicated parameter {param_name} has mismatched shapes: "
-                    f"{source_tensor.shape} vs {target_tensor.shape}, skipping"
-                )
-
-            # Direct copy for replicated parameters
-            target_tensor.copy_(source_tensor)
-        else:
-            # Need to shard the full tensor
-            sharded_tensor = self._calculate_tensor_shard(
-                source_tensor, shard_dim, self.tensor_parallel_size, self.rank
-            )
-
-            if sharded_tensor.shape != target_tensor.shape:
-                raise ValueError(
-                    f"Calculated shard for {param_name} has wrong shape: "
-                    f"{sharded_tensor.shape} vs expected {target_tensor.shape}, skipping"
-                )
-
-            target_tensor.copy_(sharded_tensor)
-
-    def _get_tensor_parallel_sharding_strategy(
-        self, param_name: str
-    ) -> tuple[int, bool]:
-        """
-        Determine the sharding strategy for a parameter in tensor parallel setup.
-
-        Returns:
-            tuple[int, bool]: (shard_dimension, is_sharded)
-                - shard_dimension: Which dimension to shard (0 or 1)
-                - is_sharded: Whether this parameter should be sharded at all
-
-        Based on vLLM's tensor parallel implementation for LLaMA models:
-        - Embedding layers: shard along vocab dimension (dim 0)
-        - Attention projections: qkv_proj shard along hidden dimension (dim 0), o_proj along input dimension (dim 1)
-        - MLP projections: gate/up_proj shard along hidden dimension (dim 0), down_proj along input dimension (dim 1)
-        - Layer norms: not sharded (replicated)
-        - Output layer: shard along vocab dimension (dim 0)
-        """
-        # Parameters that are not sharded (replicated across all tensor parallel ranks)
-        if any(keyword in param_name for keyword in ["norm", "bias", "rotary_emb"]):
-            return 0, False
-
-        # Embedding layers - shard along vocab dimension (dim 0)
-        if "embed_tokens" in param_name or "lm_head" in param_name:
-            return 0, True
-
-        # Attention projections
-        if "qkv_proj" in param_name:
-            # Input projections: shard output dimension (dim 0)
-            return 0, True
-        elif "o_proj" in param_name:
-            # Output projection: shard input dimension (dim 1)
-            return 1, True
-
-        # MLP projections
-        elif any(
-            proj in param_name for proj in ["gate_proj", "up_proj", "gate_up_proj"]
-        ):
-            # Input projections: shard output dimension (dim 0)
-            return 0, True
-        elif "down_proj" in param_name:
-            # Output projection: shard input dimension (dim 1)
-            return 1, True
-
-        # Default: try to infer from tensor shape patterns
-        return 0, True
-
-    def _calculate_tensor_shard(
-        self,
-        full_tensor: torch.Tensor,
-        shard_dim: int,
-        tensor_parallel_size: int,
-        rank: int,
-    ) -> torch.Tensor:
-        """
-        Calculate the shard of a full tensor for the current tensor parallel rank.
-
-        Args:
-            full_tensor: The full tensor to shard
-            shard_dim: Which dimension to shard along (0 or 1)
-            tensor_parallel_size: Number of tensor parallel ranks
-            rank: Current rank (will be modulo by tensor_parallel_size)
-
-        Returns:
-            torch.Tensor: The sharded tensor for this rank
-        """
-        tp_rank = rank % tensor_parallel_size
-        tensor_size = full_tensor.shape[shard_dim]
-
-        if tensor_size % tensor_parallel_size != 0:
-            raise ValueError(
-                f"Cannot shard tensor dimension {shard_dim} with size {tensor_size} "
-                f"across {tensor_parallel_size} ranks: not evenly divisible"
-            )
-
-        shard_size = tensor_size // tensor_parallel_size
-        start_idx = tp_rank * shard_size
-        end_idx = start_idx + shard_size
-
-        # Create index tensor for the shard range
-        indices = torch.arange(start_idx, end_idx, device=full_tensor.device)
-
-        if shard_dim == 0:
-            return torch.index_select(full_tensor, 0, indices)
-        elif shard_dim == 1:
-            return torch.index_select(full_tensor, 1, indices)
-        else:
-            raise ValueError(f"Unsupported shard dimension: {shard_dim}")
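For reference, the arithmetic in the deleted `_calculate_tensor_shard` reduces to taking one contiguous slice along the shard dimension. A standalone sketch of that logic (not part of the remaining codebase, shown only for readers who relied on the removed helper):

```python
import torch


def tensor_shard(full: torch.Tensor, shard_dim: int, tp_size: int, rank: int) -> torch.Tensor:
    # Equal contiguous chunks per tensor-parallel rank, as the deleted helper computed.
    size = full.shape[shard_dim]
    if size % tp_size != 0:
        raise ValueError(f"dim {shard_dim} of size {size} is not divisible by {tp_size}")
    shard_size = size // tp_size
    start = (rank % tp_size) * shard_size
    return full.narrow(shard_dim, start, shard_size)


weights = torch.arange(16).reshape(4, 4)
print(tensor_shard(weights, shard_dim=0, tp_size=2, rank=1))  # rows 2 and 3
```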
diff --git a/src/forge/data_models/episode.py b/src/forge/data_models/episode.py
deleted file mode 100644
index 6b908ff87..000000000
--- a/src/forge/data_models/episode.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from dataclasses import dataclass
-from typing import Sequence
-
-import torch
-
-from forge.data_models.scored_completion import ScoredCompletion
-
-
-@dataclass
-class Episode:
-    """
-    The Episode data class to be used by the trainer.
-
-    Episodes are usually generated from a scored completion and running various post processing steps.
-    """
-
-    # Concatenated prompt and sample token ids.
-    ids: torch.Tensor
-
-    # The mask for the target ids, 0 for prompt tokens, 1 for sample tokens.
-    mask: torch.Tensor
-
-    # The weight to apply to the loss of each target token. It's normally computed
-    # from the advantage and the reward.
-    weights: torch.Tensor
-
-    # The log probabilities of the target tokens, for prompt part it's set to 0,
-    # for generation part it's computed from the Generator/Sampler.
-    log_probs: torch.Tensor | None = None
-
-    # TODO: add more fields as required
-    state: str = ""
-
-
-def from_scored_completion(scored_completion: ScoredCompletion) -> Episode:
-    """Converts a ScoredCompletion to an Episode."""
-    prompt_ids = scored_completion.completion.prompt_ids
-    token_ids = scored_completion.completion.token_ids
-    log_probs = scored_completion.completion.log_probs
-    ids = torch.cat([prompt_ids, token_ids])
-    mask = torch.cat(
-        [
-            torch.zeros(prompt_ids.shape, dtype=torch.float32),
-            torch.ones_like(token_ids, dtype=torch.float32),
-        ]
-    )
-    advantage = scored_completion.score
-    weights = mask * advantage
-    log_probs = torch.cat(
-        [
-            torch.zeros(prompt_ids.shape, dtype=torch.float32),
-            # TODO: this only works if sample.log_probs is 1
-            log_probs,
-        ]
-    )
-    return Episode(ids=ids, mask=mask, weights=weights, log_probs=log_probs)
-
-
-def from_scored_completions(
-    scored_completions: Sequence[ScoredCompletion],
-) -> Sequence[Episode]:
-    """Converts a sequence of ScoredCompletion to a sequence of Episodes."""
-    return [from_scored_completion(sc) for sc in scored_completions]
diff --git a/src/forge/data_models/scored_completion.py b/src/forge/data_models/scored_completion.py
deleted file mode 100644
index f41ff7b59..000000000
--- a/src/forge/data_models/scored_completion.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from dataclasses import dataclass
-
-from forge.data_models.completion import Completion
-
-
-@dataclass
-class ScoredCompletion:
-    """A completion with an associated score (from a reward model or human)."""
-
-    completion: Completion
-    score: float  # akin to reward
-
-    # TODO: add more fields as needed.
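The deleted `from_scored_completion` captured one piece of bookkeeping worth keeping in mind downstream: the mask and loss weights are zero over prompt tokens and advantage-scaled over sampled tokens. A self-contained sketch of that arithmetic with toy tensors (not forge APIs):

```python
import torch

prompt_ids = torch.tensor([101, 102, 103])  # toy stand-in for Completion.prompt_ids
token_ids = torch.tensor([7, 8])            # toy stand-in for Completion.token_ids
advantage = 0.5                             # the ScoredCompletion score

ids = torch.cat([prompt_ids, token_ids])
mask = torch.cat([
    torch.zeros(prompt_ids.shape, dtype=torch.float32),
    torch.ones_like(token_ids, dtype=torch.float32),
])
weights = mask * advantage  # loss weight: 0 on the prompt, advantage on the sample
print(ids, mask, weights, sep="\n")
```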
diff --git a/src/forge/interfaces.py b/src/forge/interfaces.py
index 526e36c56..8a4ca06ef 100644
--- a/src/forge/interfaces.py
+++ b/src/forge/interfaces.py
@@ -7,30 +7,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Mapping
 
-from forge.types import Message, Observation, Scalar
-
-
-class Transform(ABC):
-    """Abstract base class for observation transforms.
-
-    Transforms are first-class citizens that can modify observations,
-    typically to add rewards, compute metrics, or modify state.
-
-    They follow a functional interface where they take an observation
-    and return a (potentially modified) observation.
-    """
-
-    @abstractmethod
-    def __call__(self, observation: Observation) -> Observation:
-        """Transform an observation.
-
-        Args:
-            observation: The input observation to transform
-
-        Returns:
-            The transformed observation (may be the same instance if no changes)
-        """
-        pass
+from forge.types import Message, Scalar
 
 
 class BaseTokenizer(ABC):
@@ -139,12 +116,3 @@ def close(self) -> None:
         This will automatically be called via __del__ when the instance goes out of scope.
         Logs should not be written after `close` is called.
     """
-
-
-class Reward(ABC):
-    """Abstract base class for reward models."""
-
-    @abstractmethod
-    def __call__(self, observation: Observation) -> float:
-        """Compute a reward for an observation."""
-        pass
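With `Transform` and `Reward` removed from `forge.interfaces`, the remaining contract is structural: anything callable with the right shape works. Callers that still want a checkable type could use a `Protocol` instead of inheritance; this is an assumption about how downstream code might adapt, not something this diff introduces:

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SampleTransform(Protocol):
    """Structural stand-in for the removed Transform ABC (hypothetical)."""

    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: ...


def add_marker(sample: dict[str, Any]) -> dict[str, Any]:
    return {**sample, "seen": True}


assert isinstance(add_marker, SampleTransform)  # plain functions qualify
print(add_marker({"text": "hi"}))
```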