
Commit 86fa173

wip: perf evaluate_service
1 parent f5b2254 commit 86fa173

14 files changed: +56 -22 lines changed
Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
 python3 -m graphgen.run \
-    --config_file examples/evaluate/evaluate_kg/evaluate_kg_config.yaml
+    --config_file examples/evaluate/evaluate_kg/kg_evaluation_config.yaml

examples/evaluate/evaluate_qa/evaluate.sh

Lines changed: 0 additions & 2 deletions
This file was deleted.
Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+python3 -m graphgen.run \
+    --config_file examples/evaluate/evaluate_qa/qa_evaluation_config.yaml

examples/evaluate/evaluate_qa/qa_evaluation_config.yaml

Lines changed: 4 additions & 2 deletions

@@ -92,5 +92,7 @@ nodes:
 metrics:
   - qa_length
   - qa_mtld
-  - qa_reward_score
-  - qa_uni_score
+  # - qa_reward_score
+  # - qa_uni_score
+mtld_params:
+  threshold: 0.7
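
The commented-out reward/uni metrics and the new mtld_params block line up with the **kwargs parameter added to EvaluateService below. A minimal sketch of how the node's options might reach the operator, assuming graphgen.run forwards a node's extra config keys as keyword arguments (that wiring is not shown in this diff):

# Hypothetical mapping from the YAML node above to the constructor call
from graphgen.operators import EvaluateService

service = EvaluateService(
    working_dir="cache",
    metrics=["qa_length", "qa_mtld"],   # qa_reward_score / qa_uni_score stay disabled
    mtld_params={"threshold": 0.7},     # read back via self.kwargs.get("mtld_params", {})
)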

graphgen/models/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -1,5 +1,4 @@
 from .evaluator import (
-    KGQualityEvaluator,
     LengthEvaluator,
     MTLDEvaluator,
     RewardEvaluator,

graphgen/models/evaluator/kg/accuracy_evaluator.py

Lines changed: 5 additions & 2 deletions

@@ -152,7 +152,9 @@ async def _evaluate_entity_extraction(
     ) -> Dict[str, Any]:
         """Use LLM to evaluate entity extraction quality."""
         try:
-            prompt = ENTITY_EVALUATION_PROMPT.format(
+            lang = detect_main_language(chunk.content)
+
+            prompt = ACCURACY_EVALUATION_PROMPT[lang]["ENTITY"].format(
                 chunk_content=chunk.content,
                 extracted_entities=json.dumps(
                     extracted_entities, ensure_ascii=False, indent=2
@@ -225,7 +227,8 @@ async def _evaluate_relation_extraction(
     ) -> Dict[str, Any]:
         """Use LLM to evaluate relation extraction quality."""
         try:
-            prompt = RELATION_EVALUATION_PROMPT.format(
+            lang = detect_main_language(chunk.content)
+            prompt = ACCURACY_EVALUATION_PROMPT[lang]["RELATION"].format(
                 chunk_content=chunk.content,
                 extracted_relations=json.dumps(
                     extracted_relations, ensure_ascii=False, indent=2
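
The updated code indexes ACCURACY_EVALUATION_PROMPT first by the detected language and then by "ENTITY" or "RELATION". The prompt module itself is not part of this diff; a minimal sketch of the assumed shape, with detect_main_language presumed to return codes such as "en" and "zh":

# Assumed structure only; the real templates are defined elsewhere in the repo, not in this commit.
ACCURACY_EVALUATION_PROMPT = {
    "en": {
        "ENTITY": "Text:\n{chunk_content}\n\nScore these extracted entities:\n{extracted_entities}",
        "RELATION": "Text:\n{chunk_content}\n\nScore these extracted relations:\n{extracted_relations}",
    },
    "zh": {
        "ENTITY": "...",    # Chinese-language variant, content not shown here
        "RELATION": "...",
    },
}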

graphgen/models/evaluator/qa/length_evaluator.py

Lines changed: 4 additions & 2 deletions

@@ -1,10 +1,12 @@
+import os
+
 from graphgen.bases import BaseEvaluator, QAPair
 from graphgen.models.tokenizer import Tokenizer


 class LengthEvaluator(BaseEvaluator):
-    def __init__(self, tokenizer: Tokenizer):
-        self.tokenizer = tokenizer
+    def __init__(self):
+        self.tokenizer: Tokenizer = Tokenizer(os.environ.get("TOKENIZER_MODEL") or "cl100k_base")

     def evaluate(self, pair: QAPair) -> float:
         """

graphgen/operators/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -7,6 +7,8 @@
 from .quiz import QuizService
 from .read import read
 from .search import SearchService
+from .evaluate import EvaluateService
+

 operators = {
     "read": read,
@@ -18,4 +20,5 @@
     "search": SearchService,
     "partition": PartitionService,
     "generate": GenerateService,
+    "evaluate": EvaluateService,
 }
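
Registering "evaluate" in the operators dict lets a pipeline step resolve to EvaluateService by name. A hedged dispatch sketch; the step/params layout is an assumption about how graphgen.run consumes this registry:

# Hypothetical lookup through the registry updated above
from graphgen.operators import operators

step = {"op": "evaluate", "params": {"metrics": ["qa_length", "qa_mtld"]}}
op_cls = operators[step["op"]]  # -> EvaluateService
op = op_cls(working_dir="cache", **step["params"])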
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .evaluate_service import EvaluateService

graphgen/operators/evaluate/evaluate_service.py

Lines changed: 19 additions & 11 deletions

@@ -10,27 +10,35 @@ class EvaluateService(BaseOperator):
     2. QA Quality Evaluation
     """

-    def __init__(self, working_dir: str = "cache", metrics: list[str] = None):
+    def __init__(self, working_dir: str = "cache", metrics: list[str] = None, **kwargs):
         super().__init__(working_dir=working_dir, op_name="evaluate_service")
         self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
         self.metrics = metrics
-
-        self.evaluators = {
-            "xxx": "xxxEvaluator"
-        }
-
-        self.graph_storage = init_storage(
-            xx, xx, xx
-        )
+        self.kwargs = kwargs
+        self.evaluators = {}

     def _init_evaluators(self):
         for metric in self.metrics:
-
+            if metric == "qa_length":
+                from graphgen.models import LengthEvaluator
+                self.evaluators[metric] = LengthEvaluator()
+            elif metric == "qa_mtld":
+                from graphgen.models import MTLDEvaluator
+                self.evaluators[metric] = MTLDEvaluator(self.kwargs.get("mtld_params", {}))
+            elif metric == "qa_reward_score":
+                from graphgen.models import RewardEvaluator
+                self.evaluators[metric] = RewardEvaluator(self.kwargs.get("reward_params", {}))
+            elif metric == "qa_uni_score":
+                from graphgen.models import UniEvaluator
+                self.evaluators[metric] = UniEvaluator(self.kwargs.get("uni_params", {}))
+            else:
+                raise ValueError(f"Unknown metric: {metric}")

     def process(self, batch: pd.DataFrame) -> pd.DataFrame:
         items = batch.to_dict(orient="records")
         return pd.DataFrame(self.evaluate(items))

     def evaluate(self, items: list[dict]) -> list[dict]:
-        # evaluate items with the configured evaluators
+        print(items)
         pass
+
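
In this WIP commit evaluate() only prints the items and returns None, so process() yields an empty DataFrame. One possible body, sketched under the assumption that each item carries question/answer fields accepted by QAPair; this is not the author's implementation:

# Hypothetical follow-up, not part of this commit
from graphgen.bases import QAPair

def evaluate(self, items: list[dict]) -> list[dict]:
    self._init_evaluators()
    for item in items:
        pair = QAPair(question=item.get("question", ""), answer=item.get("answer", ""))
        for name, evaluator in self.evaluators.items():
            item[name] = evaluator.evaluate(pair)  # e.g. item["qa_length"], item["qa_mtld"]
    return items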