@@ -12,8 +12,10 @@
 import torch
 from datasets import load_dataset
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.replay_buffer import ReplayBuffer
 from forge.controller import ServiceConfig, spawn_service
 from forge.controller.actor import ForgeActor
+from forge.data.rewards import MathReward, ThinkingReward
 from monarch.actor import endpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -260,15 +262,15 @@ def thinking_scoring_function(prompt: str, response: str, target: str) -> float:
 class RewardActor(ForgeActor):
     """Reward actor that uses a list of scoring functions."""

-    def __init__(self, scoring_functions: list[Callable]):
+    def __init__(self, reward_functions: list[Callable]):
         super().__init__()
-        self.scoring_functions = scoring_functions
+        self.reward_functions = reward_functions

     @endpoint
     async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
         total_reward = 0.0
-        for scoring_fn in self.scoring_functions:
-            reward = scoring_fn(prompt, response, target)
+        for reward_fn in self.reward_functions:
+            reward = reward_fn(prompt, response, target)
             total_reward += reward
         return total_reward

@@ -447,7 +449,7 @@ async def main():
     reward_actor = await spawn_service(
         default_service_cfg,
         RewardActor,
-        scoring_functions=[math_scoring_function, thinking_scoring_function],
+        reward_functions=[MathReward(), ThinkingReward()],
     )

     print("All services initialized successfully!")
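As context for the rename, here is a minimal sketch of the aggregation pattern `RewardActor.evaluate_response` follows after this change: a list of reward callables, each scoring a `(prompt, response, target)` triple, summed into a single scalar. The `exact_match_reward` and `thinking_tag_reward` helpers below are hypothetical stand-ins for illustration only; the actual `MathReward` and `ThinkingReward` classes imported from `forge.data.rewards` are not shown in this diff.

```python
from typing import Callable


def exact_match_reward(prompt: str, response: str, target: str) -> float:
    # Toy stand-in: reward 1.0 when the response contains the target verbatim.
    return 1.0 if target.strip() in response else 0.0


def thinking_tag_reward(prompt: str, response: str, target: str) -> float:
    # Toy stand-in: reward responses that include an explicit <think>...</think> block.
    return 0.5 if "<think>" in response and "</think>" in response else 0.0


def total_reward(reward_functions: list[Callable], prompt: str, response: str, target: str) -> float:
    # Mirrors the loop in RewardActor.evaluate_response: sum every reward function's score.
    return sum(fn(prompt, response, target) for fn in reward_functions)


if __name__ == "__main__":
    score = total_reward(
        [exact_match_reward, thinking_tag_reward],
        prompt="What is 2 + 2?",
        response="<think>2 + 2 = 4</think> The answer is 4.",
        target="4",
    )
    print(score)  # 1.5 under these toy reward functions
```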