Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 2 additions & 71 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@

import asyncio
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable
from typing import Callable

import torch
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors.policy import Policy
Expand All @@ -25,6 +23,7 @@
from forge.controller.provisioner import shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.losses.grpo_loss import SimpleGRPOLoss
from forge.types import Episode, Group
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from omegaconf import DictConfig
Expand All @@ -49,74 +48,6 @@ def compute_logprobs(
return logprobs


@dataclass
class Episode:
# TODO: add adtional layer for multi-turn
episode_id: str
request: str
policy_version: int
pad_id: int
request_len: int
response_len: int
target: Any | None = None
# processed data
response: str | None = None
request_tokens: list[int] | None = None
response_tokens: list[int] | None = None
ref_logprobs: torch.Tensor | None = None
reward: float | None = None
advantage: float | None = None

@property
def request_tensor(self):
tensor = torch.tensor(self.request_tokens, dtype=torch.long)
if tensor.shape[0] < self.request_len: # left pad
diff = self.request_len - tensor.shape[0]
tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
return tensor

@property
def response_tensor(self):
tensor = torch.tensor(self.response_tokens, dtype=torch.long)
if tensor.shape[0] < self.response_len: # right pad
diff = self.response_len - tensor.shape[0]
tensor = F.pad(tensor, (0, diff), value=self.pad_id)
return tensor


@dataclass
class Group:
group_id: str
episodes: list[Episode]

@classmethod
def new_group(
cls,
group_id: int,
group_size: int,
request: str,
policy_version: int,
pad_id: int,
request_len: int,
response_len: int,
target: Any = None,
):
episodes = []
for _ in range(group_size):
episodes.append(
Episode(
episode_id=str(uuid.uuid4()),
request=request,
policy_version=policy_version,
pad_id=pad_id,
request_len=request_len,
response_len=response_len,
target=target,
)
)
return cls(str(group_id), episodes)


@dataclass
class Trainer(ForgeActor):
"""GRPO Trainer implementation for policy optimization."""
Expand Down
Loading
Loading