-from typing import Any
+from typing import Any, Dict, List, Union

 import pandas as pd

 from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair
 from graphgen.common import init_llm
-from graphgen.utils import run_concurrent
+from graphgen.models import KGQualityEvaluator
+from graphgen.utils import logger, run_concurrent


 class EvaluateService(BaseOperator):
@@ -13,40 +14,67 @@ class EvaluateService(BaseOperator):
     2. QA Quality Evaluation
     """

-    def __init__(self, working_dir: str = "cache", metrics: list[str] = None, **kwargs):
+    def __init__(
+        self,
+        working_dir: str = "cache",
+        metrics: list[str] = None,
+        graph_backend: str = "kuzu",
+        kv_backend: str = "rocksdb",
+        **kwargs
+    ):
         super().__init__(working_dir=working_dir, op_name="evaluate_service")
         self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
-        self.metrics = metrics
+        self.metrics = metrics or []
         self.kwargs = kwargs
-        self.evaluators = {}
+        self.graph_backend = graph_backend
+        self.kv_backend = kv_backend
+
+        # Separate QA and KG metrics
+        self.qa_metrics = [m for m in self.metrics if m.startswith("qa_")]
+        self.kg_metrics = [m for m in self.metrics if m.startswith("kg_")]
+
+        # Initialize evaluators
+        self.qa_evaluators = {}
+        self.kg_evaluator = None
+
         self._init_evaluators()

     def _init_evaluators(self):
-        for metric in self.metrics:
+        """Initialize QA and KG evaluators based on metrics."""
+        # Initialize QA evaluators
+        for metric in self.qa_metrics:
             if metric == "qa_length":
                 from graphgen.models import LengthEvaluator

-                self.evaluators[metric] = LengthEvaluator()
+                self.qa_evaluators[metric] = LengthEvaluator()
             elif metric == "qa_mtld":
                 from graphgen.models import MTLDEvaluator
-
-                self.evaluators[metric] = MTLDEvaluator(
+                self.qa_evaluators[metric] = MTLDEvaluator(
                     **self.kwargs.get("mtld_params", {})
                 )
             elif metric == "qa_reward_score":
                 from graphgen.models import RewardEvaluator
-
-                self.evaluators[metric] = RewardEvaluator(
+                self.qa_evaluators[metric] = RewardEvaluator(
                     **self.kwargs.get("reward_params", {})
                 )
             elif metric == "qa_uni_score":
                 from graphgen.models import UniEvaluator
-
-                self.evaluators[metric] = UniEvaluator(
+                self.qa_evaluators[metric] = UniEvaluator(
                     **self.kwargs.get("uni_params", {})
                 )
             else:
-                raise ValueError(f"Unknown metric: {metric}")
+                raise ValueError(f"Unknown QA metric: {metric}")
+
+        # Initialize KG evaluator if KG metrics are specified
+        if self.kg_metrics:
+            kg_params = self.kwargs.get("kg_params", {})
+            self.kg_evaluator = KGQualityEvaluator(
+                working_dir=self.working_dir,
+                graph_backend=self.graph_backend,
+                kv_backend=self.kv_backend,
+                **kg_params
+            )
+            logger.info("KG evaluator initialized")

     async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]:
         try:
@@ -61,7 +89,7 @@ async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]:
             self.logger.error("Error in QAPair creation: %s", str(e))
             return {}

-        for metric, evaluator in self.evaluators.items():
+        for metric, evaluator in self.qa_evaluators.items():
             try:
                 score = evaluator.evaluate(qa_pair)
                 if isinstance(score, dict):
@@ -92,18 +120,98 @@ def transform_messages_format(items: list[dict]) -> list[dict]:
             transformed.append({"question": question, "answer": answer})
         return transformed

-    def evaluate(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]:
         if not items:
             return []

+        if not self.qa_evaluators:
+            logger.warning("No QA evaluators initialized, skipping QA evaluation")
+            return []
+
         items = self.transform_messages_format(items)
         results = run_concurrent(
             self._process_single,
             items,
-            desc="Evaluating items",
+            desc="Evaluating QA items",
             unit="item",
         )

         results = [item for item in results if item]
+        return results

+    def _evaluate_kg(self) -> Dict[str, Any]:
+        if not self.kg_evaluator:
+            logger.warning("No KG evaluator initialized, skipping KG evaluation")
+            return {}
+
+        results = {}
+
+        # Map metric names to evaluation functions
+        kg_metric_map = {
+            "kg_accuracy": self.kg_evaluator.evaluate_accuracy,
+            "kg_consistency": self.kg_evaluator.evaluate_consistency,
+            "kg_structure": self.kg_evaluator.evaluate_structure,
+        }
+
+        # Run KG evaluations based on metrics
+        for metric in self.kg_metrics:
+            if metric in kg_metric_map:
+                logger.info("Running %s evaluation...", metric)
+                metric_key = metric.replace("kg_", "")  # Remove "kg_" prefix
+                try:
+                    results[metric_key] = kg_metric_map[metric]()
+                except Exception as e:
+                    logger.error("Error in %s evaluation: %s", metric, str(e))
+                    results[metric_key] = {"error": str(e)}
+            else:
+                logger.warning("Unknown KG metric: %s, skipping", metric)
+
+        # If no valid metrics were found, run all evaluations
+        if not results:
+            logger.info("No valid KG metrics found, running all evaluations")
+            results = self.kg_evaluator.evaluate_all()
+
         return results
+
+    def evaluate(
+        self, items: list[dict[str, Any]] = None
+    ) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
+        # Determine evaluation type
+        has_qa_metrics = len(self.qa_metrics) > 0
+        has_kg_metrics = len(self.kg_metrics) > 0
+
+        # If items provided and QA metrics exist, do QA evaluation
+        if items is not None and has_qa_metrics:
+            return self._evaluate_qa(items)
+
+        # If KG metrics exist, do KG evaluation
+        if has_kg_metrics:
+            return self._evaluate_kg()
+
+        # No applicable metrics for this call: warn and return an empty result
+        if items is not None:
+            logger.warning("No QA metrics specified but items provided, skipping evaluation")
+            return []
+        else:
+            logger.warning("No metrics specified, skipping evaluation")
+            return {}
+
+    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
+        has_qa_metrics = len(self.qa_metrics) > 0
+        has_kg_metrics = len(self.kg_metrics) > 0
+
+        # QA evaluation: process batch items
+        if has_qa_metrics:
+            items = batch.to_dict(orient="records")
+            results = self._evaluate_qa(items)
+            return pd.DataFrame(results)
+
+        # KG evaluation: evaluate from storage
+        if has_kg_metrics:
+            results = self._evaluate_kg()
+            # Convert dict to DataFrame (single row)
+            return pd.DataFrame([results])
+
+        # No metrics specified
+        logger.warning("No metrics specified, returning empty DataFrame")
+        return pd.DataFrame()
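
For context, a minimal usage sketch of the extended service, assuming the constructor signature and the "qa_"/"kg_" metric-prefix convention shown in the diff above; the import path, parameter keys, and batch column layout are illustrative assumptions rather than confirmed APIs.

import pandas as pd

from graphgen.operators import EvaluateService  # import path assumed for illustration

# Metrics prefixed "qa_" are scored per QA pair; metrics prefixed "kg_" are
# computed over the graph persisted under working_dir.
service = EvaluateService(
    working_dir="cache",
    metrics=["qa_length", "qa_mtld", "kg_structure"],
    graph_backend="kuzu",
    kv_backend="rocksdb",
    mtld_params={},  # forwarded to MTLDEvaluator via **kwargs
)

# QA path: process() converts the batch to records and scores each item.
batch = pd.DataFrame(
    [
        {
            "messages": [
                {"role": "user", "content": "What is a knowledge graph?"},
                {"role": "assistant", "content": "A graph of entities and relations."},
            ]
        }
    ]
)  # column layout assumed from transform_messages_format
qa_scores = service.process(batch)

# KG path: with only kg_* metrics configured and no items, evaluate() reads the
# stored graph and returns a dict keyed by metric name without the "kg_" prefix.
kg_service = EvaluateService(working_dir="cache", metrics=["kg_accuracy"])
kg_report = kg_service.evaluate()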