59 changes: 6 additions & 53 deletions apps/grpo/main.py
@@ -15,6 +15,7 @@
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.controller import ServiceConfig, spawn_service
 from forge.controller.actor import ForgeActor
+from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -211,66 +212,18 @@ async def update_weights(self, policy_actor):
         self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")


-def math_scoring_function(prompt: str, response: str, target: str) -> float:
-    """Function to score math correctness."""
-    import re
-
-    # Extract expected answer from target
-    expected_answer = (
-        float(target.strip())
-        if target.strip().replace(".", "").replace("-", "").isdigit()
-        else None
-    )
-
-    # Extract model answer from response
-    patterns = [
-        r"####\s*([+-]?\d+(?:\.\d+)?)",  # GSM8K style answer format
-        r"(?:the\s+)?answer\s+is\s*([+-]?\d+(?:\.\d+)?)",
-        r"(?:answer:|result:)\s*([+-]?\d+(?:\.\d+)?)",
-        r"=\s*([+-]?\d+(?:\.\d+)?)\s*(?:\.|$)",  # equals near end
-        r"\b([+-]?\d+(?:\.\d+)?)\s*(?:\.|$)",  # number at end
-        r"([+-]?\d+(?:\.\d+)?)",  # any number (fallback)
-    ]
-
-    model_answer = None
-    response_lower = response.lower().strip()
-    for pattern in patterns:
-        matches = re.findall(pattern, response_lower)
-        if matches:
-            model_answer = float(matches[-1])
-            break
-
-    if expected_answer is None or model_answer is None:
-        return 0.1  # Partial credit for attempting
-
-    # Check if answers match (with some tolerance for floating point)
-    if abs(expected_answer - model_answer) < 1e-6:
-        return 1.0  # Correct answer
-    else:
-        return 0.0  # Incorrect answer
-
-
-def thinking_scoring_function(prompt: str, response: str, target: str) -> float:
-    """Function to score thinking tag usage."""
-    # Check if response contains <think></think> tags
-    if "<think>" in response.lower() and "</think>" in response.lower():
-        return 0.5
-    else:
-        return 0.0


 class RewardActor(ForgeActor):
     """Reward actor that uses a list of scoring functions."""

-    def __init__(self, scoring_functions: list[Callable]):
+    def __init__(self, reward_functions: list[Callable]):
         super().__init__()
-        self.scoring_functions = scoring_functions
+        self.reward_functions = reward_functions

     @endpoint
     async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
         total_reward = 0.0
-        for scoring_fn in self.scoring_functions:
-            reward = scoring_fn(prompt, response, target)
+        for reward_fn in self.reward_functions:
+            reward = reward_fn(prompt, response, target)
             total_reward += reward
         return total_reward

Expand Down Expand Up @@ -456,7 +409,7 @@ async def main():
reward_actor = await spawn_service(
default_service_cfg,
RewardActor,
scoring_functions=[math_scoring_function, thinking_scoring_function],
reward_functions=[MathReward(), ThinkingReward()],
)

print("All services initialized successfully!")
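Note on the new interface: RewardActor now treats each entry in reward_functions as a plain callable invoked as reward_fn(prompt, response, target) and sums the results. The sketch below shows what a compatible reward object might look like, inferred only from that call site; the real MathReward and ThinkingReward implementations live in forge.data.rewards and are not shown in this diff, and ExactMatchReward is a hypothetical illustration, not part of the PR.

# Minimal sketch of a reward callable compatible with RewardActor.
# Assumption: the only contract is __call__(prompt, response, target) -> float,
# as inferred from evaluate_response above. ExactMatchReward is hypothetical.
class ExactMatchReward:
    """Scores 1.0 (scaled by weight) when the target string appears verbatim in the response."""

    def __init__(self, weight: float = 1.0):
        self.weight = weight

    def __call__(self, prompt: str, response: str, target: str) -> float:
        return self.weight if target.strip() and target.strip() in response else 0.0

Because evaluate_response sums the rewards, an extra callable composes without further changes:

    reward_actor = await spawn_service(
        default_service_cfg,
        RewardActor,
        reward_functions=[MathReward(), ThinkingReward(), ExactMatchReward()],
    )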