 from typing import Any
+
 import pandas as pd
 
 from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair
@@ -18,21 +19,32 @@ def __init__(self, working_dir: str = "cache", metrics: list[str] = None, **kwar
         self.metrics = metrics
         self.kwargs = kwargs
         self.evaluators = {}
+        self._init_evaluators()
 
     def _init_evaluators(self):
         for metric in self.metrics:
             if metric == "qa_length":
                 from graphgen.models import LengthEvaluator
+
                 self.evaluators[metric] = LengthEvaluator()
             elif metric == "qa_mtld":
                 from graphgen.models import MTLDEvaluator
-                self.evaluators[metric] = MTLDEvaluator(self.kwargs.get("mtld_params", {}))
+
+                self.evaluators[metric] = MTLDEvaluator(
+                    **self.kwargs.get("mtld_params", {})
+                )
             elif metric == "qa_reward_score":
                 from graphgen.models import RewardEvaluator
-                self.evaluators[metric] = RewardEvaluator(self.kwargs.get("reward_params", {}))
+
+                self.evaluators[metric] = RewardEvaluator(
+                    **self.kwargs.get("reward_params", {})
+                )
             elif metric == "qa_uni_score":
                 from graphgen.models import UniEvaluator
-                self.evaluators[metric] = UniEvaluator(self.kwargs.get("uni_params", {}))
+
+                self.evaluators[metric] = UniEvaluator(
+                    **self.kwargs.get("uni_params", {})
+                )
             else:
                 raise ValueError(f"Unknown metric: {metric}")
 
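The hunk above switches each evaluator constructor from receiving the params dict as a single positional argument to `**`-unpacking it into keyword arguments. A minimal, self-contained sketch of the difference; the evaluator class and its `window_size` parameter below are hypothetical stand-ins, not GraphGen API:

```python
# Hypothetical evaluator used only to illustrate the calling-convention change.
class SomeEvaluator:
    def __init__(self, window_size: int = 50):
        self.window_size = window_size


params = {"window_size": 100}

# Old style: the whole dict is bound to the first positional parameter,
# so window_size ends up holding {"window_size": 100}.
wrong = SomeEvaluator(params)

# New style: each key becomes a keyword argument, so window_size == 100.
right = SomeEvaluator(**params)

print(wrong.window_size, right.window_size)
```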
@@ -44,16 +56,13 @@ async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]:
         try:
             qa_pair = QAPair(
                 question=str(item.get("question", "")),
-                answer=str(item.get("answer", ""))
+                answer=str(item.get("answer", "")),
             )
             if not qa_pair.question or not qa_pair.answer:
                 self.logger.error("Empty question or answer, skipping.")
                 return {}
         except Exception as e:
-            self.logger.error(
-                "Error in QAPair creation: %s",
-                str(e)
-            )
+            self.logger.error("Error in QAPair creation: %s", str(e))
             return {}
 
         for metric, evaluator in self.evaluators.items():
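This hunk also collapses the multi-line `logger.error` calls into single calls while keeping logging's lazy %-formatting, where arguments are only interpolated if the record is actually emitted. A standalone sketch of that pattern using only the standard library:

```python
import logging

logger = logging.getLogger(__name__)

try:
    raise ValueError("boom")
except Exception as e:
    # Lazy %-style formatting: the message string is built only when
    # the log record is handled, not at the call site.
    logger.error("Error in QAPair creation: %s", str(e))
```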
@@ -65,17 +74,33 @@ async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]:
                 else:
                     item[metric] = float(score)
             except Exception as e:
-                self.logger.error(
-                    "Error in %s evaluation: %s",
-                    metric,
-                    str(e)
-                )
+                self.logger.error("Error in %s evaluation: %s", metric, str(e))
                 item[metric] = None
+        return item
+
+    @staticmethod
+    def transform_messages_format(items: list[dict]) -> list[dict]:
+        """
+        Transform from [{'messages': [...]}, ...] to [{'question': '...', 'answer': '...'}, ...]
+        """
+        transformed = []
+        for item in items:
+            messages = item.get("messages", [])
+            question = next(
+                (m["content"] for m in messages if m.get("role") == "user"), ""
+            )
+            answer = next(
+                (m["content"] for m in messages if m.get("role") == "assistant"), ""
+            )
+
+            transformed.append({"question": question, "answer": answer})
+        return transformed
 
     def evaluate(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]:
         if not items:
             return []
 
+        items = self.transform_messages_format(items)
         results = run_concurrent(
             self._process_single,
             items,
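For reference, a standalone sketch of what the new `transform_messages_format` helper does to a chat-format record before evaluation; the sample data below is illustrative only:

```python
items = [
    {
        "messages": [
            {"role": "user", "content": "What is a knowledge graph?"},
            {"role": "assistant", "content": "A graph of entities and relations."},
        ]
    }
]

# Same logic as the static method added in the diff above.
transformed = []
for item in items:
    messages = item.get("messages", [])
    question = next((m["content"] for m in messages if m.get("role") == "user"), "")
    answer = next((m["content"] for m in messages if m.get("role") == "assistant"), "")
    transformed.append({"question": question, "answer": answer})

print(transformed)
# -> [{'question': 'What is a knowledge graph?',
#      'answer': 'A graph of entities and relations.'}]
```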