Skip to content

Commit 55c32be

Browse files
committed
Updated rewards (just played around a bit)
1 parent e31f815 commit 55c32be

File tree

1 file changed

+52
-47
lines changed

1 file changed

+52
-47
lines changed

src/forge/data/rewards.py

Lines changed: 52 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -17,62 +17,67 @@ def __init__(self, tolerance: float = 1e-6, partial_credit: float = 0.1):
17 17
self.tolerance = tolerance
18 18
self.partial_credit = partial_credit
19 19

20-
def _to_float(self, text) -> Optional[float]:
21-
"""Safely parse a string into a float, or return None if invalid."""
22-
if text is None:
23-
return None
24-
try:
25-
return float(str(text).strip())
26-
except (ValueError, TypeError):
27-
return None
28-
29-
def _extract_number(self, text: str) -> Optional[float]:
30-
"""Try to extract a numeric answer from text."""
31-
number_pattern = r"([+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)"
32-
patterns = [
33-
r"####\s*" + number_pattern,
34-
r"(?:the\s+)?answer\s+is\s*" + number_pattern,
35-
r"(?:answer:|result:)\s*" + number_pattern,
36-
r"\$" + number_pattern, # currency
37-
number_pattern, # fallback
38-
r"=\s*" + number_pattern + r"\s*(?:\.|$)",
39-
r"\b" + number_pattern + r"\s*(?:\.|$)",
40-
]
41-
text = text.lower().strip()
42-
for pattern in patterns:
43-
matches = re.findall(pattern, text)
44-
if matches:
45-
return self._to_float(matches[-1])
46-
return None
47-
48 20
def __call__(self, prompt: str, response: str, target: str) -> float:
49 21
"""Compute math correctness reward."""
50-
# Parse expected
51-
expected_answer = self._to_float(target)
22+
target_number = self._to_float(target)
23+
if target_number is None:
24+
return 0.0
52 25

53-
# Parse response
54-
model_answer = self._extract_number(response)
26+
# Look for answer in <answer></answer> tags
27+
answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
55 28

56-
# Scoring
57-
if expected_answer is None or model_answer is None:
58-
return self.partial_credit # Partial credit for attempting
29+
if answer_match:
30+
model_answer = self._to_float(answer_match.group(1).strip())
31+
if (
32+
model_answer is not None
33+
and abs(target_number - model_answer) < self.tolerance
34+
):
35+
return 1.0 # Correct answer
59 36

60-
if abs(expected_answer - model_answer) < self.tolerance:
61-
return 1.0 # Correct answer
62-
return 0.0 # Incorrect answer
37+
# Check for partial credit: target number appears elsewhere in response
38+
response_without_answer_tags = re.sub(
39+
r"<answer>.*?</answer>", "", response, flags=re.DOTALL
40+
)
41+
# Convert to int if it's a whole number to avoid "117.0" vs "117" mismatch
42+
target_str = (
43+
str(int(target_number))
44+
if target_number.is_integer()
45+
else str(target_number)
46+
)
47+
if target_str in response_without_answer_tags:
48+
return self.partial_credit
49+
50+
return 0.0 # No match
51+
52+
def _to_float(self, text: str) -> float | None:
53+
"""Convert text to float, return None if invalid."""
54+
try:
55+
# Remove common non-numeric characters like $, commas, etc.
56+
cleaned_text = re.sub(r"[$,\s]", "", text.strip())
57+
return float(cleaned_text)
58+
except (ValueError, AttributeError):
59+
return None
63 60

64 61

65 62
class ThinkingReward(Reward):
66 63
"""Reward class for evaluating use of <think> tags in reasoning."""
67 64

68-
def __init__(self, reward_value: float = 0.5):
69-
self.reward_value = reward_value
65+
def __init__(self, partial_reward: float = 0.2, full_reward: float = 1.0):
66+
self.partial_reward = partial_reward
67+
self.full_reward = full_reward
68+
self._THINK_BLOCK_RE = re.compile(
69+
r"<\s*think\s*>(.*?)<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL
70+
)
71+
self._THINK_TAG_ATTEMPT_RE = re.compile(r"<\s*/?\s*think\s*>", re.IGNORECASE)
70 72

71-
def __call__(
72-
self, prompt: str, response: str, target: Optional[str] = None
73-
) -> float:
74-
"""Check if response contains <think>...</think> tags."""
75-
resp = response.lower()
76-
if "<think>" in resp and "</think>" in resp:
77-
return self.reward_value
73+
def __call__(self, prompt: str, response: str, target: str | None = None) -> float:
74+
matches = self._THINK_BLOCK_RE.findall(response or "")
75+
has_well_formed = any(len(re.sub(r"\s+", "", m)) >= 1 for m in matches)
76+
has_attempt = bool(self._THINK_TAG_ATTEMPT_RE.search(response or "")) or bool(
77+
matches
78+
)
79+
if has_well_formed:
80+
return self.full_reward
81+
elif has_attempt:
82+
return self.partial_reward
78 83
return 0.0

0 commit comments

Comments
 (0)