@@ -10,27 +10,35 @@ class EvaluateService(BaseOperator):
1010 2. QA Quality Evaluation
1111 """
1212
13- def __init__ (self , working_dir : str = "cache" , metrics : list [str ] = None ):
13+ def __init__ (self , working_dir : str = "cache" , metrics : list [str ] = None , ** kwargs ):
1414 super ().__init__ (working_dir = working_dir , op_name = "evaluate_service" )
1515 self .llm_client : BaseLLMWrapper = init_llm ("synthesizer" )
1616 self .metrics = metrics
17-
18- self .evaluators = {
19- "xxx" : "xxxEvaluator"
20- }
21-
22- self .graph_storage = init_storage (
23- xx , xx , xx
24- )
17+ self .kwargs = kwargs
18+ self .evaluators = {}
2519
2620 def _init_evaluators (self ):
2721 for metric in self .metrics :
28-
22+ if metric == "qa_length" :
23+ from graphgen .models import LengthEvaluator
24+ self .evaluators [metric ] = LengthEvaluator ()
25+ elif metric == "qa_mtld" :
26+ from graphgen .models import MTLDEvaluator
27+ self .evaluators [metric ] = MTLDEvaluator (self .kwargs .get ("mtld_params" , {}))
28+ elif metric == "qa_reward_score" :
29+ from graphgen .models import RewardEvaluator
30+ self .evaluators [metric ] = RewardEvaluator (self .kwargs .get ("reward_params" , {}))
31+ elif metric == "qa_uni_score" :
32+ from graphgen .models import UniEvaluator
33+ self .evaluators [metric ] = UniEvaluator (self .kwargs .get ("uni_params" , {}))
34+ else :
35+ raise ValueError (f"Unknown metric: { metric } " )
2936
3037 def process (self , batch : pd .DataFrame ) -> pd .DataFrame :
3138 items = batch .to_dict (orient = "records" )
3239 return pd .DataFrame (self .evaluate (items ))
3340
3441 def evaluate (self , items : list [dict ]) -> list [dict ]:
35- # 用evaluators 评估 items
42+ print ( items )
3643 pass
44+
0 commit comments