Skip to content
Merged
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import torch
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller import ServiceConfig, spawn_service
from forge.controller.actor import ForgeActor
from forge.data.rewards import MathReward, ThinkingReward
from monarch.actor import endpoint
from transformers import AutoModelForCausalLM, AutoTokenizer

Expand Down Expand Up @@ -260,15 +262,15 @@ def thinking_scoring_function(prompt: str, response: str, target: str) -> float:
class RewardActor(ForgeActor):
"""Reward actor that uses a list of scoring functions."""

- def __init__(self, scoring_functions: list[Callable]):
+ def __init__(self, reward_functions: list[Callable]):
super().__init__()
- self.scoring_functions = scoring_functions
+ self.reward_functions = reward_functions

@endpoint
async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
total_reward = 0.0
- for scoring_fn in self.scoring_functions:
- reward = scoring_fn(prompt, response, target)
+ for reward_fn in self.reward_functions:
+ reward = reward_fn(prompt, response, target)
total_reward += reward
return total_reward

Expand Down Expand Up @@ -447,7 +449,7 @@ async def main():
reward_actor = await spawn_service(
default_service_cfg,
RewardActor,
- scoring_functions=[math_scoring_function, thinking_scoring_function],
+ reward_functions=[MathReward(), ThinkingReward()],
)

print("All services initialized successfully!")
Expand Down
Loading