
Commit a4d7993

refactor: change evaluation methods in acc and consistency to sync
1 parent 2a3f09f commit a4d7993

File tree: 2 files changed, +61 / -76 lines

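In essence, the old code built its own event loop (create_event_loop / loop.run_until_complete) and fanned LLM calls out with asyncio.gather under a semaphore; the new code walks the items in a plain for-loop and wraps each still-async LLM call in its own asyncio.run(). A minimal before/after sketch of that pattern, using a hypothetical ask_llm coroutine as a stand-in for self.llm_client.generate_answer (not code from the repo):

    import asyncio

    async def ask_llm(prompt: str) -> str:
        # stand-in for the async LLM wrapper call; hypothetical
        await asyncio.sleep(0)
        return f"answer to: {prompt}"

    # Before: async orchestration, bounded concurrency, one event loop per batch
    async def evaluate_all_async(prompts, max_concurrent=10):
        semaphore = asyncio.Semaphore(max_concurrent)

        async def one(p):
            async with semaphore:
                return await ask_llm(p)

        return await asyncio.gather(*(one(p) for p in prompts), return_exceptions=True)

    # After: synchronous loop, each awaited call wrapped in its own asyncio.run()
    def evaluate_all_sync(prompts):
        results = []
        for p in prompts:
            try:
                results.append(asyncio.run(ask_llm(p)))
            except Exception as e:  # mirrors the per-item error handling in the diff
                print(f"failed: {e}")
                continue
        return results

    if __name__ == "__main__":
        print(asyncio.run(evaluate_all_async(["a", "b"])))
        print(evaluate_all_sync(["a", "b"]))

One consequence of the sync style is that asyncio.run() raises a RuntimeError when invoked from inside an already-running event loop, so the refactored evaluators are meant to be called from plain synchronous code.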

graphgen/models/evaluator/kg/accuracy_evaluator.py

Lines changed: 30 additions & 35 deletions
@@ -6,7 +6,7 @@
 from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
 from graphgen.bases.datatypes import Chunk
 from graphgen.templates import ACCURACY_EVALUATION_PROMPT
-from graphgen.utils import create_event_loop, detect_main_language, logger
+from graphgen.utils import detect_main_language, logger


 class AccuracyEvaluator:
@@ -43,10 +43,7 @@ def evaluate(self) -> Dict[str, Any]:
         logger.info(f"Found {len(chunks)} chunks to evaluate")

         # 2. Evaluate each chunk
-        loop = create_event_loop()
-        entity_evaluations, relation_evaluations = loop.run_until_complete(
-            self._evaluate_all_chunks(chunks)
-        )
+        entity_evaluations, relation_evaluations = self._evaluate_all_chunks(chunks)

         # 3. Aggregate results
         return self._aggregate_evaluation_results(
@@ -112,54 +109,47 @@ def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]:

         return relations

-    async def _evaluate_all_chunks(
+    def _evaluate_all_chunks(
         self, chunks: List[Chunk]
     ) -> tuple[List[Dict], List[Dict]]:
-        """Evaluate all chunks concurrently."""
-        semaphore = asyncio.Semaphore(self.max_concurrent)
+        """Evaluate all chunks sequentially."""
+        entity_evaluations = []
+        relation_evaluations = []

-        async def evaluate_chunk(chunk: Chunk):
-            async with semaphore:
+        for chunk in chunks:
+            try:
                 entities = self._get_extracted_entities_for_chunk(chunk.id)
                 relations = self._get_extracted_relations_for_chunk(chunk.id)

-                entity_eval = await self._evaluate_entity_extraction(chunk, entities)
-                relation_eval = await self._evaluate_relation_extraction(
-                    chunk, relations
-                )
-
-                return entity_eval, relation_eval
+                entity_eval = self._evaluate_entity_extraction(chunk, entities)
+                relation_eval = self._evaluate_relation_extraction(chunk, relations)

-        tasks = [evaluate_chunk(chunk) for chunk in chunks]
-        results = await asyncio.gather(*tasks, return_exceptions=True)
-
-        entity_evaluations = []
-        relation_evaluations = []
-
-        for i, result in enumerate(results):
-            if isinstance(result, Exception):
-                logger.error(f"Failed to evaluate chunk {chunks[i].id}: {result}")
+                entity_evaluations.append(entity_eval)
+                relation_evaluations.append(relation_eval)
+            except Exception as e:
+                logger.error(f"Failed to evaluate chunk {chunk.id}: {e}")
                 continue

-            entity_eval, relation_eval = result
-            entity_evaluations.append(entity_eval)
-            relation_evaluations.append(relation_eval)
-
         return entity_evaluations, relation_evaluations

-    async def _evaluate_entity_extraction(
+    def _evaluate_entity_extraction(
         self, chunk: Chunk, extracted_entities: List[Dict]
     ) -> Dict[str, Any]:
         """Use LLM to evaluate entity extraction quality."""
         try:
-            prompt = ENTITY_EVALUATION_PROMPT.format(
+            lang = detect_main_language(chunk.content)
+            prompt_template = ACCURACY_EVALUATION_PROMPT.get(lang, {}).get("ENTITY")
+            if not prompt_template:
+                prompt_template = ACCURACY_EVALUATION_PROMPT.get("en", {}).get("ENTITY")
+
+            prompt = prompt_template.format(
                 chunk_content=chunk.content,
                 extracted_entities=json.dumps(
                     extracted_entities, ensure_ascii=False, indent=2
                 ),
            )

-            response = await self.llm_client.generate_answer(prompt)
+            response = asyncio.run(self.llm_client.generate_answer(prompt))

             # Try to parse JSON response
             try:
@@ -220,19 +210,24 @@ async def _evaluate_entity_extraction(
                 "issues": [f"Evaluation error: {str(e)}"],
             }

-    async def _evaluate_relation_extraction(
+    def _evaluate_relation_extraction(
         self, chunk: Chunk, extracted_relations: List[Dict]
     ) -> Dict[str, Any]:
         """Use LLM to evaluate relation extraction quality."""
         try:
-            prompt = RELATION_EVALUATION_PROMPT.format(
+            lang = detect_main_language(chunk.content)
+            prompt_template = ACCURACY_EVALUATION_PROMPT.get(lang, {}).get("RELATION")
+            if not prompt_template:
+                prompt_template = ACCURACY_EVALUATION_PROMPT.get("en", {}).get("RELATION")
+
+            prompt = prompt_template.format(
                 chunk_content=chunk.content,
                 extracted_relations=json.dumps(
                     extracted_relations, ensure_ascii=False, indent=2
                 ),
             )

-            response = await self.llm_client.generate_answer(prompt)
+            response = asyncio.run(self.llm_client.generate_answer(prompt))

             # Try to parse JSON response
             try:
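The new prompt selection treats ACCURACY_EVALUATION_PROMPT as a nested mapping keyed first by language and then by "ENTITY" / "RELATION", falling back to English when a language has no template. A minimal sketch of that lookup; the dictionary shape here is an assumption inferred from the .get(lang, {}).get(...) calls, and the real templates live in graphgen.templates:

    # Assumed shape of ACCURACY_EVALUATION_PROMPT, inferred from the diff; the
    # real templates in graphgen.templates will differ in wording and coverage.
    ACCURACY_EVALUATION_PROMPT = {
        "en": {
            "ENTITY": "Text:\n{chunk_content}\n\nExtracted entities:\n{extracted_entities}\n\nScore the extraction.",
            "RELATION": "Text:\n{chunk_content}\n\nExtracted relations:\n{extracted_relations}\n\nScore the extraction.",
        },
        "zh": {
            "ENTITY": "文本：\n{chunk_content}\n\n抽取的实体：\n{extracted_entities}\n\n请评估抽取质量。",
        },
    }

    def pick_template(lang: str, kind: str) -> str:
        """Same fallback logic as the diff: language-specific first, then English."""
        template = ACCURACY_EVALUATION_PROMPT.get(lang, {}).get(kind)
        if not template:
            template = ACCURACY_EVALUATION_PROMPT.get("en", {}).get(kind)
        return template

    # "zh" has no RELATION entry, so this falls back to the English template.
    print(pick_template("zh", "RELATION")[:20])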

graphgen/models/evaluator/kg/consistency_evaluator.py

Lines changed: 31 additions & 41 deletions
@@ -5,7 +5,13 @@

 from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper
 from graphgen.bases.datatypes import Chunk
-from graphgen.utils import create_event_loop, logger
+from graphgen.templates.evaluation.kg.consistency_evaluation import (
+    ENTITY_DESCRIPTION_CONFLICT_PROMPT,
+    ENTITY_EXTRACTION_PROMPT,
+    ENTITY_TYPE_CONFLICT_PROMPT,
+    RELATION_CONFLICT_PROMPT,
+)
+from graphgen.utils import logger


 class ConsistencyEvaluator:
@@ -20,24 +26,21 @@ def __init__(
         graph_storage: BaseGraphStorage,
         chunk_storage: BaseKVStorage,
         llm_client: BaseLLMWrapper,
-        max_concurrent: int = 10,
     ):
         self.graph_storage = graph_storage
         self.chunk_storage = chunk_storage
         self.llm_client = llm_client
-        self.max_concurrent = max_concurrent

     def evaluate(self) -> Dict[str, Any]:
         """Evaluate consistency by detecting semantic conflicts."""
         all_nodes = self.graph_storage.get_all_nodes() or []
         if not all_nodes:
             return {"error": "Empty graph"}

-        loop = create_event_loop()
-        return loop.run_until_complete(self._evaluate_consistency(all_nodes))
+        return self._evaluate_consistency(all_nodes)

-    async def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]:
-        """Async evaluation of consistency."""
+    def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]:
+        """Evaluate consistency by detecting semantic conflicts."""
         # Filter entities with multiple source chunks
         entities_with_multiple_sources = []
         for node_id, node_data in all_nodes:
@@ -63,35 +66,22 @@ async def _evaluate_consistency(self, all_nodes: List) -> Dict[str, Any]:
             f"Checking consistency for {len(entities_with_multiple_sources)} entities with multiple sources"
         )

-        # Evaluate entities concurrently
-        semaphore = asyncio.Semaphore(self.max_concurrent)
-
-        async def evaluate_entity(entity_info):
-            async with semaphore:
-                return await self._evaluate_entity_consistency(entity_info)
-
-        tasks = [
-            evaluate_entity(entity_info)
-            for entity_info in entities_with_multiple_sources
-        ]
-        results = await asyncio.gather(*tasks, return_exceptions=True)
-
-        # Aggregate results
+        # Evaluate entities sequentially
         conflicts = []
         conflict_entities = set()

-        for i, result in enumerate(results):
-            if isinstance(result, Exception):
+        for entity_info in entities_with_multiple_sources:
+            try:
+                entity_id, entity_conflicts = self._evaluate_entity_consistency(entity_info)
+                if entity_conflicts:
+                    conflicts.extend(entity_conflicts)
+                    conflict_entities.add(entity_id)
+            except Exception as e:
                 logger.error(
-                    f"Failed to evaluate entity {entities_with_multiple_sources[i][0]}: {result}"
+                    f"Failed to evaluate entity {entity_info[0]}: {e}"
                 )
                 continue

-            entity_id, entity_conflicts = result
-            if entity_conflicts:
-                conflicts.extend(entity_conflicts)
-                conflict_entities.add(entity_id)
-
         total_entities = len(all_nodes)
         conflict_rate = (
             len(conflict_entities) / total_entities if total_entities > 0 else 0
@@ -114,7 +104,7 @@ def _clean_entity_id(self, entity_id: str) -> str:
             clean_id = clean_id[1:-1].strip()
         return clean_id

-    async def _evaluate_entity_consistency(
+    def _evaluate_entity_consistency(
         self, entity_info: tuple
     ) -> tuple[str, List[Dict]]:
         """Evaluate consistency for a single entity."""
@@ -131,7 +121,7 @@ async def _evaluate_entity_consistency(
         # Extract entity attributes from each chunk
         entity_extractions = {}
         for chunk in chunks:
-            extraction = await self._extract_entity_from_chunk(entity_id, chunk)
+            extraction = self._extract_entity_from_chunk(entity_id, chunk)
             if extraction:
                 entity_extractions[chunk.id] = extraction

@@ -143,7 +133,7 @@ async def _evaluate_entity_consistency(
             chunk_id: ext.get("entity_type", "")
             for chunk_id, ext in entity_extractions.items()
         }
-        type_conflict = await self._check_entity_type_consistency(
+        type_conflict = self._check_entity_type_consistency(
             entity_id, type_extractions
         )
         if type_conflict and type_conflict.get("has_conflict", False):
@@ -163,7 +153,7 @@ async def _evaluate_entity_consistency(
             chunk_id: ext.get("description", "")
             for chunk_id, ext in entity_extractions.items()
         }
-        desc_conflict = await self._check_entity_description_consistency(
+        desc_conflict = self._check_entity_description_consistency(
             entity_id, descriptions
         )
         if desc_conflict and desc_conflict.get("has_conflict", False):
@@ -196,7 +186,7 @@ def _get_entity_chunks(self, source_ids: List[str]) -> List[Chunk]:
                 continue
         return chunks

-    async def _extract_entity_from_chunk(
+    def _extract_entity_from_chunk(
         self, entity_id: str, chunk: Chunk
     ) -> Dict[str, str]:
         """Extract entity attributes from a chunk using LLM."""
@@ -211,7 +201,7 @@ async def _extract_entity_from_chunk(
                 else "",  # Limit content length
             )

-            response = await self.llm_client.generate_answer(prompt)
+            response = asyncio.run(self.llm_client.generate_answer(prompt))

             # Try to parse JSON response
             try:
@@ -265,7 +255,7 @@ async def _extract_entity_from_chunk(
             )
             return {}

-    async def _check_entity_type_consistency(
+    def _check_entity_type_consistency(
         self, entity_id: str, type_extractions: Dict[str, str]
     ) -> Dict[str, Any]:
         """Check entity type consistency using LLM."""
@@ -284,7 +274,7 @@ async def _check_entity_type_consistency(
                 entity_name=entity_id, type_extractions="\n".join(type_list)
             )

-            response = await self.llm_client.generate_answer(prompt)
+            response = asyncio.run(self.llm_client.generate_answer(prompt))

             # Parse JSON response
             try:
@@ -304,7 +294,7 @@ async def _check_entity_type_consistency(
             logger.error(f"Error checking type consistency for {entity_id}: {e}")
             return {"has_conflict": False}

-    async def _check_entity_description_consistency(
+    def _check_entity_description_consistency(
         self, entity_id: str, descriptions: Dict[str, str]
     ) -> Dict[str, Any]:
         """Check entity description consistency using LLM."""
@@ -327,7 +317,7 @@ async def _check_entity_description_consistency(
                 entity_name=entity_id, descriptions="\n".join(desc_list)
             )

-            response = await self.llm_client.generate_answer(prompt)
+            response = asyncio.run(self.llm_client.generate_answer(prompt))

             # Parse JSON response
             try:
@@ -347,7 +337,7 @@ async def _check_entity_description_consistency(
             logger.error(f"Error checking description consistency for {entity_id}: {e}")
             return {"has_conflict": False}

-    async def _check_relation_consistency(
+    def _check_relation_consistency(
         self, src_id: str, dst_id: str, relation_extractions: Dict[str, str]
     ) -> Dict[str, Any]:
         """Check relation consistency using LLM."""
@@ -367,7 +357,7 @@ async def _check_relation_consistency(
                 relation_descriptions="\n".join(rel_list),
             )

-            response = await self.llm_client.generate_answer(prompt)
+            response = asyncio.run(self.llm_client.generate_answer(prompt))

             # Parse JSON response
             try:
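After this commit, a caller of ConsistencyEvaluator no longer passes max_concurrent and no longer needs to manage an event loop around evaluate(). A hypothetical usage sketch; the import path simply mirrors the file path shown above, and the storage / LLM wrapper objects are placeholders for whatever concrete BaseGraphStorage, BaseKVStorage, and BaseLLMWrapper implementations the project provides:

    # Hypothetical wiring; only the constructor arguments and the synchronous
    # evaluate() call follow the diff.
    from graphgen.models.evaluator.kg.consistency_evaluator import ConsistencyEvaluator

    def run_consistency_check(graph_storage, chunk_storage, llm_client) -> dict:
        evaluator = ConsistencyEvaluator(
            graph_storage=graph_storage,
            chunk_storage=chunk_storage,
            llm_client=llm_client,  # max_concurrent is removed by this commit
        )
        # evaluate() is now a plain synchronous call; no create_event_loop() /
        # run_until_complete() wrapper is needed on the caller side.
        return evaluator.evaluate()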
