Skip to content

Commit fac7d02

Browse files
committed
integrate reward functions
1 parent fd1d38b commit fac7d02

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

apps/grpo/main.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
import torch
1313
from datasets import load_dataset
1414
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
15+
from forge.actors.replay_buffer import ReplayBuffer
1516
from forge.controller import ServiceConfig, spawn_service
1617
from forge.controller.actor import ForgeActor
18+
from forge.data.rewards import MathReward, ThinkingReward
1719
from monarch.actor import endpoint
1820
from transformers import AutoModelForCausalLM, AutoTokenizer
1921

@@ -260,15 +262,15 @@ def thinking_scoring_function(prompt: str, response: str, target: str) -> float:
260262
class RewardActor(ForgeActor):
261263
"""Reward actor that uses a list of scoring functions."""
262264

263-
def __init__(self, scoring_functions: list[Callable]):
265+
def __init__(self, reward_functions: list[Callable]):
264266
super().__init__()
265-
self.scoring_functions = scoring_functions
267+
self.reward_functions = reward_functions
266268

267269
@endpoint
268270
async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
269271
total_reward = 0.0
270-
for scoring_fn in self.scoring_functions:
271-
reward = scoring_fn(prompt, response, target)
272+
for reward_fn in self.reward_functions:
273+
reward = reward_fn(prompt, response, target)
272274
total_reward += reward
273275
return total_reward
274276

@@ -447,7 +449,7 @@ async def main():
447449
reward_actor = await spawn_service(
448450
default_service_cfg,
449451
RewardActor,
450-
scoring_functions=[math_scoring_function, thinking_scoring_function],
452+
reward_functions=[MathReward(), ThinkingReward()],
451453
)
452454

453455
print("All services initialized successfully!")

0 commit comments

Comments
 (0)