diff --git a/torchtune/dev/rl/rewards.py b/torchtune/dev/rl/rewards.py
index 8d1ec1e79f..95c45ee9b0 100644
--- a/torchtune/dev/rl/rewards.py
+++ b/torchtune/dev/rl/rewards.py
@@ -7,10 +7,15 @@
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Union
import torch
+from torchtune.modules.transforms.tokenizers import (
+ HuggingFaceModelTokenizer,
+ ModelTokenizer,
+)
+
@dataclass
class RewardOutput:
@@ -216,3 +221,106 @@ def __call__(
},
successes=successes,
)
+
+
+def at_least_one_space_between_think_tags(
+ cot: str, answer: str, potential_answer: str
+) -> tuple[float, float]:
+ """Did the model at least try to think?"""
+ if len(cot) > 0:
+ return 1.0, 1.0 # (reward, success)
+ else:
+ return 0.0, 0.0
+
+
+def math_response_correct(
+ cot: str, answer: str, potential_answer: str
+) -> tuple[float, float]:
+ """Did it get the right answer?"""
+ import math_verify
+
+ if potential_answer is None:
+ return 0.0, 0.0 # (reward, success)
+ gold = math_verify.parse(answer)
+ attempt = math_verify.parse(potential_answer)
+
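+    # Tiered reward: 100 for a verified match, 50 if the gold answer appears verbatim
+    # in the attempt, 1 for any non-empty attempt; only a verified match is a success.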
+ if math_verify.verify(gold, attempt):
+ return 100.0, 1.0
+ if answer in potential_answer:
+ return 50.0, 0.0
+ if len(potential_answer) > 0:
+ return 1.0, 0.0
+ return 0.0, 0.0
+
+
+def extract_tags(text: str) -> tuple[str, str]:
+ """
+ Parse XML-like tags from text. Returns a dictionary with keys 'think' and 'answer'.
+ The values are lists of strings, with each string being the content of a tag.
+ """
+ think_pattern = r"(.*?)"
+ answer_pattern = r"(.*?)"
+ think_match = re.search(think_pattern, text, re.DOTALL)
+ answer_match = re.search(answer_pattern, text, re.DOTALL)
+ cot = think_match.group(1).strip() if think_match else ""
+ potential_answer = answer_match.group(1).strip() if answer_match else ""
+ return cot, potential_answer
+
+
+def batched_rewards(
+ tokenizer: Union[ModelTokenizer, HuggingFaceModelTokenizer],
+ completions: torch.Tensor,
+    answers: list[list[str]],
+ device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor, dict]:
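+    """
+    Score each decoded completion with every reward function and collect the results.
+
+    Args:
+        tokenizer (Union[ModelTokenizer, HuggingFaceModelTokenizer]): tokenizer used to
+            decode ``completions``.
+        completions (torch.Tensor): token IDs of shape ``[batch_size, grpo_size, seq_len]``.
+        answers (list[list[str]]): ground-truth answers, indexed as ``answers[b][g]``.
+        device (torch.device): device on which the output tensors are allocated.
+
+    Returns:
+        tuple[torch.Tensor, torch.Tensor, dict]: reward and success tensors of shape
+            ``[batch_size, grpo_size, num_reward_funcs]``, plus metadata with the reward
+            function names.
+    """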
+
+ reward_funcs = [
+ at_least_one_space_between_think_tags,
+ math_response_correct,
+ ]
+
+ num_reward_funcs = len(reward_funcs)
+
+ batch_size, grpo_size, _ = completions.shape
+
+ # TODO: should this be bfloat16?
+
+ rewards_tensor = torch.zeros(
+ batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
+ )
+
+ successes_tensor = torch.zeros(
+ batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
+ )
+
+ metadata = {"func_names": [f.__name__ for f in reward_funcs]}
+
+    for b in range(batch_size):
+        for g in range(grpo_size):
+            answer = answers[b][g]
+            text_completion = tokenizer.decode(completions[b, g].tolist())
+            cot, potential_answer = extract_tags(text_completion)
+            for rw_idx, reward_func in enumerate(reward_funcs):
+                reward, success = reward_func(cot, answer, potential_answer)
+                rewards_tensor[b, g, rw_idx] += reward
+                successes_tensor[b, g, rw_idx] += success
+
+    return rewards_tensor, successes_tensor, metadata