@@ -66,7 +66,7 @@ class AnswerCorrectnessClassification(BaseModel):
6666 "The sun's light plays a critical role in Earth's climate system." ,
6767 "Sunlight helps to drive the weather and ocean currents." ,
6868 ],
69- "extracted_statements " : AnswerCorrectnessClassification .parse_obj (
69+ "classification " : AnswerCorrectnessClassification .parse_obj (
7070 {
7171 "TP" : [
7272 {
@@ -114,7 +114,7 @@ class AnswerCorrectnessClassification(BaseModel):
114114 "The boiling point of water is 100 degrees Celsius (212 degrees Fahrenheit) at sea level." ,
115115 "The boiling point of water can change with altitude." ,
116116 ],
117- "extracted_statements " : AnswerCorrectnessClassification .parse_obj (
117+ "classification " : AnswerCorrectnessClassification .parse_obj (
118118 {
119119 "TP" : [
120120 {
@@ -134,7 +134,7 @@ class AnswerCorrectnessClassification(BaseModel):
134134 },
135135 ],
136136 input_keys = ["question" , "answer" , "ground_truth" ],
137- output_key = "extracted_statements " ,
137+ output_key = "classification " ,
138138 output_type = "json" ,
139139)
140140
@@ -231,26 +231,36 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks, is_async: bool) -> fl
231231 statements [item ].dicts () if statements [item ] is not None else []
232232 )
233233
234- if any (val is [] for val in statements .values ()):
235- return np .nan
236-
237- p_value = self .correctness_prompt .format (
238- question = question ,
239- ground_truth = statements ["ground_truth" ],
240- answer = statements ["answer" ],
241- )
242- is_statement_present = await self .llm .generate (
243- p_value , callbacks = callbacks , is_async = is_async
244- )
245- result_text = is_statement_present .generations [0 ][0 ].text
234+ if not all ([val == [] for val in statements .values ()]):
235+ ground_truth = [
236+ statement
237+ for item in statements ["ground_truth" ]
238+ for statement in item ["simpler_statements" ]
239+ ]
240+ answer = [
241+ statement
242+ for item in statements ["answer" ]
243+ for statement in item ["simpler_statements" ]
244+ ]
245+ p_value = self .correctness_prompt .format (
246+ question = question ,
247+ ground_truth = ground_truth ,
248+ answer = answer ,
249+ )
250+ is_statement_present = await self .llm .generate (
251+ p_value , callbacks = callbacks , is_async = is_async
252+ )
253+ result_text = is_statement_present .generations [0 ][0 ].text
246254
247- answers = await _output_parser .aparse (
248- result_text , p_value , self .llm , self .max_retries
249- )
250- if answers is None :
251- return np .nan
255+ answers = await _output_parser .aparse (
256+ result_text , p_value , self .llm , self .max_retries
257+ )
258+ if answers is None :
259+ return np .nan
252260
253- f1_score = self ._compute_statement_presence (answers )
261+ f1_score = self ._compute_statement_presence (answers )
262+ else :
263+ f1_score = 1.0
254264
255265 if self .weights [1 ] == 0 :
256266 similarity_score = 0.0
0 commit comments