@@ -18,6 +18,7 @@ class EvaluationMode(Enum):
 
     BEST_RESPONSE = auto()
     RESPONSE_CHECK = auto()
+    MATH_CHECK = auto()
 
 
 @dataclass
@@ -28,7 +29,65 @@ class Judge(Policy):
     and may require more postprocessing
     """
 
-    def _response_check(self, prompt: str, responses: list[str]) -> str:
+    def _math_check(
+        self,
+        prompt: str,
+        responses: list[str],
+        ground_truth: None | str = None,
+    ) -> str:
38+ """
39+ Construct the generator input. Formats the request such that the generator
40+ will return a comma separated list with a [[GOOD]] or [[BAD]] evaluation
41+ for each response, corresponding to whether the model thinks the response
42+ matches the provided ground_truth. Specifically the generator is prompted to
43+ check for mathematical equivalence
44+
45+ Note: This is not a "good" prompt, it just demonstrates how to make one
46+ """
+
+        if ground_truth is None:
+            raise ValueError("MATH_CHECK requires a ground_truth to compare against")
+
+        system_prompt = f"""
+        You are a math professor. Given the prompt and ground truth solution, evaluate
+        each of the provided attempts and return whether the final solution is
+        numerically equivalent to the ground truth.
+
+        Each response is formatted as [Response #<N>], where <N> represents the
+        attempt.
+
+        Your answer should be a comma separated list of "[[GOOD]]" or "[[BAD]]",
+        corresponding to the same order as the responses provided.
+
+        - If the answer is irrelevant to the prompt, return "[[BAD]]".
+        - If you are not confident that the solution and attempt are equivalent, return "[[BAD]]".
+        - Only return "[[GOOD]]" if the attempt is numerically equivalent.
+
+        Do not explain your reasoning, just provide your evaluations.
+        ---
+        Here is the prompt that generated the responses: {prompt}.
+        ---
+        Here is the ground truth: {ground_truth}
+        """
+        response_str = "\n".join(
+            [f"[Response #{i + 1}] {resp}" for i, resp in enumerate(responses)]
+        )
+        as_chat = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": response_str},
+        ]
+        tokenizer = self.processor.tokenizer.tokenizer
+        formatted_request = tokenizer.apply_chat_template(
+            as_chat, tokenize=False, add_generation_prompt=True
+        )
+        return formatted_request
+
+    def _response_check(
+        self,
+        prompt: str,
+        responses: list[str],
+        ground_truth: None | str = None,
+    ) -> str:
         """
         Construct the generator input. Formats the request such that the generator
         will return a comma separated list with a [[GOOD]] or [[BAD]] evaluation
@@ -67,7 +126,12 @@ def _response_check(self, prompt: str, responses: list[str]) -> str:
         )
         return formatted_request
 
-    def _best_response(self, prompt: str, responses: list[str]) -> str:
+    def _best_check(
+        self,
+        prompt: str,
+        responses: list[str],
+        ground_truth: None | str = None,
+    ) -> str:
         """
         Construct the generator input. Format the request such that the generator
         will respond with a single integer corresponding to the response the model
@@ -105,14 +169,18 @@ async def evaluate(
         self,
         prompt: str,
         responses: None | list[str] = None,
+        ground_truth: None | str = None,
         evaluation_mode: EvaluationMode = EvaluationMode.BEST_RESPONSE,
     ) -> list[str]:
         _prompting: dict = {
-            EvaluationMode.BEST_RESPONSE: self._response_check,
-            EvaluationMode.ANSWER_CHECK: self._answer_check,
+            EvaluationMode.BEST_RESPONSE: self._best_check,
+            EvaluationMode.RESPONSE_CHECK: self._response_check,
+            EvaluationMode.MATH_CHECK: self._math_check,
         }
 
-        wrapped_prompt: str = _prompting[evaluation_mode](prompt, responses)
+        wrapped_prompt: str = _prompting[evaluation_mode](
+            prompt, responses, ground_truth
+        )
         response: List[Completion] = await self.generate._method(self, wrapped_prompt)
         return self._postprocess_output(response)
 
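
A minimal usage sketch for the new mode follows. The grade_math_attempts wrapper and the judge handle are hypothetical and not part of this change; only EvaluationMode.MATH_CHECK, the evaluate signature, and the [[GOOD]]/[[BAD]] verdict convention come from the diff above.

# Hypothetical caller -- assumes `judge` is an instance (or actor handle) exposing
# the Judge.evaluate method added above; nothing here is defined by the diff itself.
async def grade_math_attempts(judge, prompt: str, attempts: list[str], answer: str) -> list[str]:
    # Ask the judge to compare each attempt against the ground-truth answer.
    verdicts = await judge.evaluate(
        prompt=prompt,
        responses=attempts,
        ground_truth=answer,
        evaluation_mode=EvaluationMode.MATH_CHECK,
    )
    # Per the prompt contract, each returned entry should contain [[GOOD]] or
    # [[BAD]], in the same order as `attempts`.
    return verdicts
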
@@ -125,10 +193,14 @@ class RewardModelJudge(Policy):
125193 """
126194
127195 # TODO: Add reward models formatting
128- def wrapped_prompt (self , prompt : str , responses : list [str ]) -> str :
196+ def wrapped_prompt (
197+ self , prompt : str , responses : list [str ], ground_truth : None | str = None
198+ ) -> str :
129199 return prompt
130200
131- def _postprocess_output (self , outputs : list [Completion ]) -> list [str ]:
201+ def _postprocess_output (
202+ self , outputs : list [Completion ], ground_truth : None | str = None
203+ ) -> list [str ]:
132204 return [output .text for output in outputs ]
133205
134206 @endpoint
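
The diff does not show how the Judge class turns the generator's comma separated verdicts into per-response results. A minimal sketch of one way to split them is below; parse_verdicts is a hypothetical helper, and only the "[[GOOD]]"/"[[BAD]]" format comes from the prompts above.

# Hypothetical helper -- not part of the diff. Splits a judge completion such as
# "[[GOOD]], [[BAD]], [[GOOD]]" into one boolean per response, treating anything
# malformed or missing as a failed check.
def parse_verdicts(completion_text: str, num_responses: int) -> list[bool]:
    parts = [p.strip() for p in completion_text.split(",")]
    verdicts = [p == "[[GOOD]]" for p in parts]
    # Pad or truncate so callers always get exactly one verdict per response.
    verdicts = verdicts[:num_responses]
    verdicts += [False] * (num_responses - len(verdicts))
    return verdicts

# Example: parse_verdicts("[[GOOD]], [[BAD]]", 2) -> [True, False]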