diff --git a/torchtune/dev/rl/rewards.py b/torchtune/dev/rl/rewards.py
index 8d1ec1e79f..95c45ee9b0 100644
--- a/torchtune/dev/rl/rewards.py
+++ b/torchtune/dev/rl/rewards.py
@@ -7,10 +7,15 @@
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 
+from torchtune.modules.transforms.tokenizers import (
+    HuggingFaceModelTokenizer,
+    ModelTokenizer,
+)
+
 
 @dataclass
 class RewardOutput:
@@ -216,3 +221,96 @@ def __call__(
             },
             successes=successes,
         )
+
+
+def at_least_one_space_between_think_tags(
+    cot: str, answer: str, potential_answer: str
+) -> tuple[float, float]:
+    """Did the model at least try to think?"""
+    if len(cot) > 0:
+        return 1.0, 1.0  # (reward, success)
+    else:
+        return 0.0, 0.0
+
+
+def math_response_correct(
+    cot: str, answer: str, potential_answer: str
+) -> tuple[float, float]:
+    """Did it get the right answer?"""
+    import math_verify
+
+    if potential_answer is None:
+        return 0.0, 0.0  # (reward, success)
+    gold = math_verify.parse(answer)
+    attempt = math_verify.parse(potential_answer)
+
+    if math_verify.verify(gold, attempt):
+        return 100.0, 1.0
+    if answer in potential_answer:
+        return 50.0, 0.0
+    if len(potential_answer) > 0:
+        return 1.0, 0.0
+    return 0.0, 0.0
+
+
+def extract_tags(text: str) -> tuple[str, str]:
+    """
+    Parse the <think> and <answer> XML-like tags from text. Returns a tuple
+    (cot, potential_answer) with the stripped content of each tag, or "" for a missing tag.
+    """
+    think_pattern = r"<think>(.*?)</think>"
+    answer_pattern = r"<answer>(.*?)</answer>"
+    think_match = re.search(think_pattern, text, re.DOTALL)
+    answer_match = re.search(answer_pattern, text, re.DOTALL)
+    cot = think_match.group(1).strip() if think_match else ""
+    potential_answer = answer_match.group(1).strip() if answer_match else ""
+    return cot, potential_answer
+
+
+def batched_rewards(
+    tokenizer: Union[ModelTokenizer, HuggingFaceModelTokenizer],
+    completions: torch.Tensor,
+    answers: list[list[str]],
+    device: torch.device,
+) -> tuple[torch.Tensor, torch.Tensor, dict]:
+
+    reward_funcs = [
+        at_least_one_space_between_think_tags,
+        math_response_correct,
+    ]
+
+    num_reward_funcs = len(reward_funcs)
+
+    batch_size, grpo_size, _ = completions.shape
+
+    # TODO: should this be bfloat16?
+
+    rewards_tensor = torch.zeros(
+        batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
+    )
+
+    successes_tensor = torch.zeros(
+        batch_size, grpo_size, num_reward_funcs, dtype=torch.float32, device=device
+    )
+
+    metadata = {"func_names": [f.__name__ for f in reward_funcs]}
+
+    for b in range(batch_size):
+
+        for g in range(grpo_size):
+
+            answer = answers[b][g]
+
+            text_completion = tokenizer.decode(completions[b, g].tolist())
+
+            cot, potential_answer = extract_tags(f"{text_completion}")
+
+            for rw_idx, reward_func in enumerate(reward_funcs):
+
+                reward, success = reward_func(cot, answer, potential_answer)
+
+                rewards_tensor[b, g, rw_idx] += reward
+
+                successes_tensor[b, g, rw_idx] += success
+
+    return rewards_tensor, successes_tensor, metadata
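
For reviewers, a minimal usage sketch of the new batched_rewards() helper added above. The stub tokenizer, token ids, and answers are invented purely for illustration, and math_response_correct additionally requires the math_verify package to be installed:

    import torch

    from torchtune.dev.rl.rewards import batched_rewards


    class StubTokenizer:
        """Stands in for a ModelTokenizer; decode() just returns a canned completion."""

        def decode(self, token_ids: list[int]) -> str:
            return "<think>2 + 2 = 4</think><answer>4</answer>"


    # completions has shape [batch_size=1, grpo_size=2, seq_len=4] of fake token ids.
    completions = torch.zeros(1, 2, 4, dtype=torch.long)
    # One gold answer per (batch, group) element, matching the answers[b][g] indexing.
    answers = [["4", "4"]]

    rewards, successes, metadata = batched_rewards(
        StubTokenizer(), completions, answers, device=torch.device("cpu")
    )
    print(rewards.shape)  # torch.Size([1, 2, 2]) -> [batch, group, reward_func]
    print(metadata["func_names"])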