
Commit 084cb08

feat: add KGQualityEvaluator and integrate into EvaluateService for KG evaluations
1 parent 4d022fb commit 084cb08

5 files changed: +213 -18 lines changed

graphgen/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .evaluator import (
+    KGQualityEvaluator,
     LengthEvaluator,
     MTLDEvaluator,
     RewardEvaluator,

graphgen/models/evaluator/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -1,2 +1,7 @@
 from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
-from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator
+from .kg import (
+    AccuracyEvaluator,
+    ConsistencyEvaluator,
+    KGQualityEvaluator,
+    StructureEvaluator,
+)

graphgen/models/evaluator/kg/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -9,10 +9,12 @@

 from .accuracy_evaluator import AccuracyEvaluator
 from .consistency_evaluator import ConsistencyEvaluator
+from .kg_quality_evaluator import KGQualityEvaluator
 from .structure_evaluator import StructureEvaluator

 __all__ = [
     "AccuracyEvaluator",
     "ConsistencyEvaluator",
+    "KGQualityEvaluator",
    "StructureEvaluator",
 ]
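
Because graphgen/models/evaluator/__init__.py and graphgen/models/__init__.py re-export the class (see the diffs above), both import paths below resolve to the same object. A quick sanity-check sketch, assuming the package is importable:

# Re-export chain: kg/__init__.py -> evaluator/__init__.py -> models/__init__.py
from graphgen.models import KGQualityEvaluator
from graphgen.models.evaluator.kg import KGQualityEvaluator as KGQualityEvaluatorDirect

assert KGQualityEvaluator is KGQualityEvaluatorDirect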

graphgen/models/evaluator/kg/kg_quality_evaluator.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+from typing import Any, Dict
+
+from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
+from graphgen.common import init_llm, init_storage
+from graphgen.models.evaluator.kg.accuracy_evaluator import AccuracyEvaluator
+from graphgen.models.evaluator.kg.consistency_evaluator import ConsistencyEvaluator
+from graphgen.models.evaluator.kg.structure_evaluator import StructureEvaluator
+from graphgen.utils import logger
+
+
+class KGQualityEvaluator:
+    def __init__(
+        self,
+        working_dir: str = "cache",
+        graph_backend: str = "kuzu",
+        kv_backend: str = "rocksdb",
+        **kwargs
+    ):
+        # Initialize storage
+        self.graph_storage: BaseGraphStorage = init_storage(
+            backend=graph_backend, working_dir=working_dir, namespace="graph"
+        )
+        self.chunk_storage: BaseKVStorage = init_storage(
+            backend=kv_backend, working_dir=working_dir, namespace="chunk"
+        )
+
+        # Initialize LLM client
+        self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
+
+        # Initialize individual evaluators
+        self.accuracy_evaluator = AccuracyEvaluator(
+            graph_storage=self.graph_storage,
+            chunk_storage=self.chunk_storage,
+            llm_client=self.llm_client,
+        )
+
+        self.consistency_evaluator = ConsistencyEvaluator(
+            graph_storage=self.graph_storage,
+            chunk_storage=self.chunk_storage,
+            llm_client=self.llm_client,
+        )
+
+        # Structure evaluator doesn't need chunk_storage or llm_client
+        structure_params = kwargs.get("structure_params", {})
+        self.structure_evaluator = StructureEvaluator(
+            graph_storage=self.graph_storage,
+            **structure_params
+        )
+
+        logger.info("KGQualityEvaluator initialized")
+
+    def evaluate_accuracy(self) -> Dict[str, Any]:
+        logger.info("Running accuracy evaluation...")
+        results = self.accuracy_evaluator.evaluate()
+        logger.info("Accuracy evaluation completed")
+        return results
+
+    def evaluate_consistency(self) -> Dict[str, Any]:
+        logger.info("Running consistency evaluation...")
+        results = self.consistency_evaluator.evaluate()
+        logger.info("Consistency evaluation completed")
+        return results
+
+    def evaluate_structure(self) -> Dict[str, Any]:
+        logger.info("Running structural robustness evaluation...")
+        results = self.structure_evaluator.evaluate()
+        logger.info("Structural robustness evaluation completed")
+        return results
+
+    def evaluate_all(self) -> Dict[str, Any]:
+        logger.info("Running all KG evaluations...")
+        results = {
+            "accuracy": self.evaluate_accuracy(),
+            "consistency": self.evaluate_consistency(),
+            "structure": self.evaluate_structure(),
+        }
+        logger.info("All KG evaluations completed")
+        return results
+
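
A minimal standalone usage sketch for the new class, assuming a cache/ directory already populated by a previous GraphGen run (so the "graph" and "chunk" namespaces exist) and credentials configured for init_llm("synthesizer"):

from graphgen.models import KGQualityEvaluator

# Backends mirror the defaults in the constructor above.
evaluator = KGQualityEvaluator(
    working_dir="cache",
    graph_backend="kuzu",
    kv_backend="rocksdb",
)

# Run one dimension at a time, or everything in a single call.
structure_scores = evaluator.evaluate_structure()
all_scores = evaluator.evaluate_all()  # keys: "accuracy", "consistency", "structure"
print(sorted(all_scores))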

graphgen/operators/evaluate/evaluate_service.py

Lines changed: 125 additions & 17 deletions
@@ -1,10 +1,11 @@
-from typing import Any
+from typing import Any, Dict, List, Union

 import pandas as pd

 from graphgen.bases import BaseLLMWrapper, BaseOperator, QAPair
 from graphgen.common import init_llm
-from graphgen.utils import run_concurrent
+from graphgen.models import KGQualityEvaluator
+from graphgen.utils import logger, run_concurrent


 class EvaluateService(BaseOperator):
@@ -13,40 +14,67 @@ class EvaluateService(BaseOperator):
     2. QA Quality Evaluation
     """

-    def __init__(self, working_dir: str = "cache", metrics: list[str] = None, **kwargs):
+    def __init__(
+        self,
+        working_dir: str = "cache",
+        metrics: list[str] = None,
+        graph_backend: str = "kuzu",
+        kv_backend: str = "rocksdb",
+        **kwargs
+    ):
         super().__init__(working_dir=working_dir, op_name="evaluate_service")
         self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
-        self.metrics = metrics
+        self.metrics = metrics or []
         self.kwargs = kwargs
-        self.evaluators = {}
+        self.graph_backend = graph_backend
+        self.kv_backend = kv_backend
+
+        # Separate QA and KG metrics
+        self.qa_metrics = [m for m in self.metrics if m.startswith("qa_")]
+        self.kg_metrics = [m for m in self.metrics if m.startswith("kg_")]
+
+        # Initialize evaluators
+        self.qa_evaluators = {}
+        self.kg_evaluator = None
+
         self._init_evaluators()

     def _init_evaluators(self):
-        for metric in self.metrics:
+        """Initialize QA and KG evaluators based on metrics."""
+        # Initialize QA evaluators
+        for metric in self.qa_metrics:
             if metric == "qa_length":
                 from graphgen.models import LengthEvaluator

-                self.evaluators[metric] = LengthEvaluator()
+                self.qa_evaluators[metric] = LengthEvaluator()
             elif metric == "qa_mtld":
                 from graphgen.models import MTLDEvaluator
-
-                self.evaluators[metric] = MTLDEvaluator(
+                self.qa_evaluators[metric] = MTLDEvaluator(
                     **self.kwargs.get("mtld_params", {})
                 )
             elif metric == "qa_reward_score":
                 from graphgen.models import RewardEvaluator
-
-                self.evaluators[metric] = RewardEvaluator(
+                self.qa_evaluators[metric] = RewardEvaluator(
                     **self.kwargs.get("reward_params", {})
                 )
             elif metric == "qa_uni_score":
                 from graphgen.models import UniEvaluator
-
-                self.evaluators[metric] = UniEvaluator(
+                self.qa_evaluators[metric] = UniEvaluator(
                     **self.kwargs.get("uni_params", {})
                 )
             else:
-                raise ValueError(f"Unknown metric: {metric}")
+                raise ValueError(f"Unknown QA metric: {metric}")
+
+        # Initialize KG evaluator if KG metrics are specified
+        if self.kg_metrics:
+            kg_params = self.kwargs.get("kg_params", {})
+            self.kg_evaluator = KGQualityEvaluator(
+                working_dir=self.working_dir,
+                graph_backend=self.graph_backend,
+                kv_backend=self.kv_backend,
+                **kg_params
+            )
+            logger.info("KG evaluator initialized")

     async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]:
         try:
@@ -61,7 +89,7 @@ async def _process_single(self, item: dict[str, Any]) -> dict[str, Any]:
             self.logger.error("Error in QAPair creation: %s", str(e))
             return {}

-        for metric, evaluator in self.evaluators.items():
+        for metric, evaluator in self.qa_evaluators.items():
             try:
                 score = evaluator.evaluate(qa_pair)
                 if isinstance(score, dict):
@@ -92,18 +120,98 @@ def transform_messages_format(items: list[dict]) -> list[dict]:
             transformed.append({"question": question, "answer": answer})
         return transformed

-    def evaluate(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]:
+    def _evaluate_qa(self, items: list[dict[str, Any]]) -> list[dict[str, Any]]:
         if not items:
             return []

+        if not self.qa_evaluators:
+            logger.warning("No QA evaluators initialized, skipping QA evaluation")
+            return []
+
         items = self.transform_messages_format(items)
         results = run_concurrent(
             self._process_single,
             items,
-            desc="Evaluating items",
+            desc="Evaluating QA items",
             unit="item",
         )

         results = [item for item in results if item]
+        return results

+    def _evaluate_kg(self) -> Dict[str, Any]:
+        if not self.kg_evaluator:
+            logger.warning("No KG evaluator initialized, skipping KG evaluation")
+            return {}
+
+        results = {}
+
+        # Map metric names to evaluation functions
+        kg_metric_map = {
+            "kg_accuracy": self.kg_evaluator.evaluate_accuracy,
+            "kg_consistency": self.kg_evaluator.evaluate_consistency,
+            "kg_structure": self.kg_evaluator.evaluate_structure,
+        }
+
+        # Run KG evaluations based on metrics
+        for metric in self.kg_metrics:
+            if metric in kg_metric_map:
+                logger.info("Running %s evaluation...", metric)
+                metric_key = metric.replace("kg_", "")  # Remove "kg_" prefix
+                try:
+                    results[metric_key] = kg_metric_map[metric]()
+                except Exception as e:
+                    logger.error("Error in %s evaluation: %s", metric, str(e))
+                    results[metric_key] = {"error": str(e)}
+            else:
+                logger.warning("Unknown KG metric: %s, skipping", metric)
+
+        # If no valid metrics were found, run all evaluations
+        if not results:
+            logger.info("No valid KG metrics found, running all evaluations")
+            results = self.kg_evaluator.evaluate_all()
+
         return results
+
+    def evaluate(
+        self, items: list[dict[str, Any]] = None
+    ) -> Union[List[Dict[str, Any]], Dict[str, Any]]:
+        # Determine evaluation type
+        has_qa_metrics = len(self.qa_metrics) > 0
+        has_kg_metrics = len(self.kg_metrics) > 0
+
+        # If items provided and QA metrics exist, do QA evaluation
+        if items is not None and has_qa_metrics:
+            return self._evaluate_qa(items)
+
+        # If KG metrics exist, do KG evaluation
+        if has_kg_metrics:
+            return self._evaluate_kg()
+
+        # If no metrics specified, try to infer from context
+        if items is not None:
+            logger.warning("No QA metrics specified but items provided, skipping evaluation")
+            return []
+        else:
+            logger.warning("No metrics specified, skipping evaluation")
+            return {}
+
+    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
+        has_qa_metrics = len(self.qa_metrics) > 0
+        has_kg_metrics = len(self.kg_metrics) > 0
+
+        # QA evaluation: process batch items
+        if has_qa_metrics:
+            items = batch.to_dict(orient="records")
+            results = self._evaluate_qa(items)
+            return pd.DataFrame(results)
+
+        # KG evaluation: evaluate from storage
+        if has_kg_metrics:
+            results = self._evaluate_kg()
+            # Convert dict to DataFrame (single row)
+            return pd.DataFrame([results])
+
+        # No metrics specified
+        logger.warning("No metrics specified, returning empty DataFrame")
+        return pd.DataFrame()
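
A usage sketch for the extended service, importing it from its module path and assuming the same populated cache. The metric prefixes decide which path runs; the exact record layout expected by the QA path (chat-style items fed through transform_messages_format) is not shown in this hunk, so only the KG path is exercised here:

from graphgen.operators.evaluate.evaluate_service import EvaluateService

# KG-only run: no items are needed, scores come straight from graph/chunk storage.
kg_service = EvaluateService(
    working_dir="cache",
    metrics=["kg_accuracy", "kg_structure"],
    graph_backend="kuzu",
    kv_backend="rocksdb",
)
kg_scores = kg_service.evaluate()  # dict keyed by "accuracy", "structure", ...

# QA metrics ("qa_length", "qa_mtld", ...) would instead go through
# evaluate(items) or process(batch) with chat-formatted records.
qa_service = EvaluateService(working_dir="cache", metrics=["qa_length", "qa_mtld"])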
