
Commit 00bbffa

Need to test math
1 parent 53607fd commit 00bbffa

2 files changed: +81, -9 lines

apps/vllm/judge.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,7 @@
 
 import os
 
-from forge.actors.judge import Judge
+from forge.actors.judge import EvaluationMode, Judge
 from forge.cli.config import parse
 from forge.controller.provisioner import shutdown
 
@@ -50,7 +50,7 @@ async def run(cfg: DictConfig):
     print("\nGeneration Results:")
     print("=" * 80)
     for batch, (best, fact) in enumerate(
-        zip(best_response_evaluations, fact_check_evaluations)
+        zip(best_response_evaluations, response_check_evaluations)
     ):
         print(f"Sample {batch + 1}")
         print(f"Evaluation (BEST_RESPONSE): {best}")
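For reference, `response_check_evaluations` pairs with the `EvaluationMode.RESPONSE_CHECK` mode now imported above, alongside the existing `best_response_evaluations`. A minimal sketch of how the two lists consumed by this loop might be assembled, assuming a `judge` handle whose `evaluate` endpoint can be awaited directly; the exact actor-call syntax used by the app is not shown in this diff, so treat every name here as hypothetical:

# Hypothetical illustration; `judge` and `prompts_and_candidates` are assumed names.
best_response_evaluations, response_check_evaluations = [], []
for prompt, candidates in prompts_and_candidates:
    best_response_evaluations.append(
        await judge.evaluate(
            prompt=prompt,
            responses=candidates,
            evaluation_mode=EvaluationMode.BEST_RESPONSE,
        )
    )
    response_check_evaluations.append(
        await judge.evaluate(
            prompt=prompt,
            responses=candidates,
            evaluation_mode=EvaluationMode.RESPONSE_CHECK,
        )
    )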

src/forge/actors/judge.py

Lines changed: 79 additions & 7 deletions
@@ -18,6 +18,7 @@ class EvaluationMode(Enum):
 
     BEST_RESPONSE = auto()
     RESPONSE_CHECK = auto()
+    MATH_CHECK = auto()
 
 
 @dataclass
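Reconstructed from the diff, the full enum after this commit reads as below; the `from enum import Enum, auto` import is an assumption implied by the `auto()` calls rather than something visible in this hunk:

from enum import Enum, auto  # assumed import, implied by the auto() usage


class EvaluationMode(Enum):
    BEST_RESPONSE = auto()
    RESPONSE_CHECK = auto()
    MATH_CHECK = auto()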
@@ -28,7 +29,65 @@ class Judge(Policy):
     and may require more postprocessing
     """
 
-    def _response_check(self, prompt: str, responses: list[str]) -> str:
+    def _math_check(
+        self,
+        prompt: str,
+        responses: list[str],
+        ground_truth: None | str = None,
+    ) -> str:
+        """
+        Construct the generator input. Formats the request such that the generator
+        will return a comma separated list with a [[GOOD]] or [[BAD]] evaluation
+        for each response, corresponding to whether the model thinks the response
+        matches the provided ground_truth. Specifically, the generator is prompted
+        to check for mathematical equivalence.
+
+        Note: This is not a "good" prompt; it just demonstrates how to make one.
+        """
+
+        if ground_truth is None:
+            raise ValueError("MATH_CHECK requires a ground_truth to compare against")
+
+        system_prompt = f"""
+        You are a math professor. Given the prompt and ground truth solution, evaluate
+        each of the provided attempts and return whether the final solution is
+        numerically equivalent to the ground truth.
+
+        Each response is formatted as [Response #<N>], where <N> represents the
+        attempt.
+
+        Your answer should be a comma separated list of "[[GOOD]]" or "[[BAD]]",
+        corresponding to the same order as the responses provided.
+
+        - If the answer is irrelevant to the prompt, return "[[BAD]]".
+        - If you are not confident that the solution and attempt are equivalent, return "[[BAD]]".
+        - Only return "[[GOOD]]" if the attempt is numerically equivalent.
+
+        Do not explain your reasoning, just provide your evaluations.
+        ---
+        Here is the prompt that generated the responses: {prompt}.
+        ---
+        Here is the ground truth: {ground_truth}
+        """
+        response_str = "\n".join(
+            [f"[Response #{i+1}] {resp}" for i, resp in enumerate(responses)]
+        )
+        as_chat = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": response_str},
+        ]
+        tokenizer = self.processor.tokenizer.tokenizer
+        formatted_request = tokenizer.apply_chat_template(
+            as_chat, tokenize=False, add_generation_prompt=True
+        )
+        return formatted_request
+
+    def _response_check(
+        self,
+        prompt: str,
+        responses: list[str],
+        ground_truth: None | str = None,
+    ) -> str:
         """
         Construct the generator input. Formats the request such that the generator
         will return a comma separated list with a [[GOOD]] or [[BAD]] evaluation
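`_math_check` only assembles a prompt string; the actual formatting is delegated to `apply_chat_template` on whatever tokenizer `self.processor.tokenizer.tokenizer` resolves to. A standalone sketch of that same formatting step, assuming a Hugging Face tokenizer and an example model name that is not taken from this commit:

from transformers import AutoTokenizer

# Example model only; any chat-tuned checkpoint with a chat template would do.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

as_chat = [
    {"role": "system", "content": "You are a math professor. ..."},
    {"role": "user", "content": "[Response #1] 7 * 8 = 56\n[Response #2] The product is 54"},
]

# Mirrors the call inside _math_check: render the chat to a plain string and
# append the assistant header so the judge model starts generating its verdict.
formatted_request = tokenizer.apply_chat_template(
    as_chat, tokenize=False, add_generation_prompt=True
)
print(formatted_request)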
@@ -67,7 +126,12 @@ def _response_check(self, prompt: str, responses: list[str]) -> str:
         )
         return formatted_request
 
-    def _best_response(self, prompt: str, responses: list[str]) -> str:
+    def _best_check(
+        self,
+        prompt: str,
+        responses: list[str],
+        ground_truth: None | str = None,
+    ) -> str:
         """
         Construct the generator input. Format the request such that the generator
         will respond with a single integer corresponding to the response the model
@@ -105,14 +169,18 @@ async def evaluate(
         self,
         prompt: str,
         responses: None | list[str] = None,
+        ground_truth: None | str = None,
         evaluation_mode: EvaluationMode = EvaluationMode.BEST_RESPONSE,
     ) -> list[str]:
         _prompting: dict = {
-            EvaluationMode.BEST_RESPONSE: self._response_check,
-            EvaluationMode.ANSWER_CHECK: self._answer_check,
+            EvaluationMode.BEST_RESPONSE: self._best_check,
+            EvaluationMode.RESPONSE_CHECK: self._response_check,
+            EvaluationMode.MATH_CHECK: self._math_check,
         }
 
-        wrapped_prompt: str = _prompting[evaluation_mode](prompt, responses)
+        wrapped_prompt: str = _prompting[evaluation_mode](
+            prompt, responses, ground_truth
+        )
         response: List[Completion] = await self.generate._method(self, wrapped_prompt)
         return self._postprocess_output(response)

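All three prompt builders now accept `ground_truth` so they share one call signature in the `_prompting` dispatch table, even though only `_math_check` actually uses it. A hedged usage sketch of the new mode; the `judge` handle and awaiting the endpoint directly are assumptions, and the expected output follows the [[GOOD]]/[[BAD]] contract requested in the prompt above:

# Hypothetical call site; in the real app the Judge runs as an actor and the
# endpoint is reached through forge's controller machinery.
evaluations = await judge.evaluate(
    prompt="Compute 7 * 8.",
    responses=["7 * 8 = 56", "I believe the product is 54."],
    ground_truth="56",
    evaluation_mode=EvaluationMode.MATH_CHECK,
)
# If the generator follows instructions, this comes back as raw text such as
# ["[[GOOD]], [[BAD]]"], one string per completion.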
@@ -125,10 +193,14 @@ class RewardModelJudge(Policy):
     """
 
     # TODO: Add reward models formatting
-    def wrapped_prompt(self, prompt: str, responses: list[str]) -> str:
+    def wrapped_prompt(
+        self, prompt: str, responses: list[str], ground_truth: None | str = None
+    ) -> str:
         return prompt
 
-    def _postprocess_output(self, outputs: list[Completion]) -> list[str]:
+    def _postprocess_output(
+        self, outputs: list[Completion], ground_truth: None | str = None
+    ) -> list[str]:
         return [output.text for output in outputs]
 
     @endpoint
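The judge prompts ask the generator for a comma separated list of "[[GOOD]]" / "[[BAD]]" tokens, but nothing in this commit parses that string back into per-response verdicts. A small hypothetical helper, not part of the commit, that a caller could layer on top, assuming the generator follows the requested format:

def parse_verdicts(evaluation: str) -> list[bool]:
    """Map a '[[GOOD]], [[BAD]], ...' string to one boolean per response.

    Hypothetical helper: anything other than [[GOOD]] counts as a failure,
    matching the prompt's instruction to prefer [[BAD]] when unsure.
    """
    return ["[[GOOD]]" in item for item in evaluation.split(",")]


# Example: parse_verdicts("[[GOOD]], [[BAD]]") -> [True, False]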
