Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 4 additions & 85 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import logging
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable

import torch
import torch.nn.functional as F
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.reference_model import HFReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import compute_logprobs, Episode
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
Expand All @@ -28,21 +29,6 @@
logger.setLevel(logging.DEBUG)


def compute_logprobs(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the implementation here is functionally equivalent to the implementation previously in reference_actor.py (which was moved to trainer.py)

logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
context_length = logits.shape[1] - input_ids.shape[1]

# Truncate request logits and drop last
logits = logits[:, context_length - 1 : -1]

# Compute logprobs
logprobs = torch.log_softmax(logits / temperature, dim=-1)
logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)

return logprobs


class SimpleGRPOLoss(nn.Module):
"""Simplified GRPO Loss for simplified single step updates
Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py.
Expand All @@ -68,41 +54,6 @@ def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
return loss


@dataclass
class Episode:
# TODO: add adtional layer for multi-turn
episode_id: str
request: str
policy_version: int
pad_id: int
request_len: int
response_len: int
target: Optional[Any] = None
# processed data
response: Optional[str] = None
request_tokens: Optional[list[int]] = None
response_tokens: Optional[list[int]] = None
ref_logprobs: Optional[torch.Tensor] = None
reward: Optional[float] = None
advantage: Optional[float] = None

@property
def request_tensor(self):
tensor = torch.tensor(self.request_tokens, dtype=torch.long)
if tensor.shape[0] < self.request_len: # left pad
diff = self.request_len - tensor.shape[0]
tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
return tensor

@property
def response_tensor(self):
tensor = torch.tensor(self.response_tokens, dtype=torch.long)
if tensor.shape[0] < self.response_len: # right pad
diff = self.response_len - tensor.shape[0]
tensor = F.pad(tensor, (0, diff), value=self.pad_id)
return tensor


@dataclass
class Group:
group_id: str
Expand Down Expand Up @@ -250,38 +201,6 @@ async def compute(self, group: Group) -> list[float]:
return x


class RefModel(ForgeActor):
def __init__(self, model_name, device: torch.device | None = None):
super().__init__()
self.model_name = model_name

if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device

self.model = AutoModelForCausalLM.from_pretrained(
model_name,
dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)
self.model.eval()

self.logger.info(f"Model initialized on {self.device}")

@endpoint
async def forward(self, episode: Episode) -> torch.Tensor:
req, res = episode.request_tensor, episode.response_tensor
input_ids = torch.cat([req, res]).to(self.device).unsqueeze(0)
mask = input_ids != episode.pad_id

with torch.inference_mode():
logits = self.model(input_ids=input_ids, attention_mask=mask).logits

input_ids = input_ids[:, len(req) :]
return compute_logprobs(logits, input_ids)


@dataclass
class DatasetActor(ForgeActor):
"""Actor wrapper for HuggingFace dataset to provide async interface."""
Expand Down Expand Up @@ -387,7 +306,7 @@ async def main():
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
RefModel,
HFReferenceModel,
model_name=model,
),
spawn_service(
Expand Down
6 changes: 3 additions & 3 deletions src/forge/actors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def __getattr__(name):
from .replay_buffer import ReplayBuffer

return ReplayBuffer
elif name == "TitanRefModel":
from .reference_actor import TitanRefModel
elif name == "ReferenceModel":
from .reference_model import ReferenceModel

return TitanRefModel
return ReferenceModel
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
Loading
Loading