
Commit 5bfdc0a

fix: fix format and clean up imports
1 parent 777cb25 commit 5bfdc0a

5 files changed: +80 -77 lines changed


graphgen/models/evaluator/kg/accuracy_evaluator.py

Lines changed: 33 additions & 32 deletions
@@ -84,12 +84,12 @@
class AccuracyEvaluator:
    """Evaluates accuracy of entity recognition and relation extraction using LLM-as-a-Judge.

    For each chunk, uses LLM to evaluate the quality of extracted entities and relations
    by comparing them with the original chunk content. Provides multi-dimensional quality
    scores (accuracy, completeness, precision).
    """

    def __init__(
        self,
        graph_storage: BaseGraphStorage,
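A note on the "multi-dimensional quality scores" mentioned in the docstring above: judging only from the fields this class reads back later in this diff (accuracy, completeness, precision, overall_score, precision_reasoning, issues), the judge response it parses is expected to look roughly like the sketch below. The exact prompt and any additional fields are not visible in these hunks, so the values and the commented field are illustrative assumptions, not the actual schema.

    {
        "accuracy": 0.90,
        "completeness": 0.80,
        "precision": 0.70,
        "precision_reasoning": "…",   # other *_reasoning fields are implied but not shown in this diff
        "issues": [],
        "overall_score": 0.82         # optional; recomputed from fixed weights if missing (see below)
    }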
@@ -104,48 +104,48 @@ def __init__(
    def evaluate(self) -> Dict[str, Any]:
        """Evaluate entity and relation extraction quality using LLM-as-a-Judge.

        Returns:
            Dictionary containing entity_accuracy and relation_accuracy metrics.
        """
        # 1. Load all chunks from storage
        chunks = self._load_chunks_from_storage()

        if not chunks:
            logger.warning("No chunks found in storage")
            return {"error": "No chunks found in storage"}

        logger.info(f"Found {len(chunks)} chunks to evaluate")

        # 2. Evaluate each chunk
        loop = create_event_loop()
        entity_evaluations, relation_evaluations = loop.run_until_complete(
            self._evaluate_all_chunks(chunks)
        )

        # 3. Aggregate results
        return self._aggregate_evaluation_results(entity_evaluations, relation_evaluations)

    def _load_chunks_from_storage(self) -> List[Chunk]:
        """Load all chunks from chunk storage."""
        chunks = []
        all_chunk_data = self.chunk_storage.get_all()

        for chunk_id, chunk_data in all_chunk_data.items():
            try:
                chunk = Chunk.from_dict(chunk_id, chunk_data)
                chunks.append(chunk)
            except Exception as e:
                logger.warning(f"Failed to load chunk {chunk_id}: {e}")
                continue

        return chunks

    def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]:
        """Get all entities extracted from the specified chunk."""
        entities = []
        all_nodes = self.graph_storage.get_all_nodes() or []

        for node_id, node_data in all_nodes:
            if not isinstance(node_data, dict):
                continue
@@ -157,14 +157,14 @@ def _get_extracted_entities_for_chunk(self, chunk_id: str) -> List[Dict]:
                "entity_type": node_data.get("entity_type", ""),
                "description": node_data.get("description", "")
            })

        return entities

    def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]:
        """Get all relations extracted from the specified chunk."""
        relations = []
        all_edges = self.graph_storage.get_all_edges() or []

        for src_id, dst_id, edge_data in all_edges:
            if not isinstance(edge_data, dict):
                continue
@@ -178,39 +178,40 @@ def _get_extracted_relations_for_chunk(self, chunk_id: str) -> List[Dict]:
                "target_entity": dst_node.get("entity_name", dst_id),
                "relationship_summary": edge_data.get("description", "")
            })

        return relations

    async def _evaluate_all_chunks(
        self, chunks: List[Chunk]
    ) -> tuple[List[Dict], List[Dict]]:
        """Evaluate all chunks concurrently."""
        semaphore = asyncio.Semaphore(self.max_concurrent)

        async def evaluate_chunk(chunk: Chunk):
            async with semaphore:
                entities = self._get_extracted_entities_for_chunk(chunk.id)
                relations = self._get_extracted_relations_for_chunk(chunk.id)

                entity_eval = await self._evaluate_entity_extraction(chunk, entities)
                relation_eval = await self._evaluate_relation_extraction(chunk, relations)

                return entity_eval, relation_eval

        tasks = [evaluate_chunk(chunk) for chunk in chunks]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        entity_evaluations = []
        relation_evaluations = []

        for i, result in enumerate(results):
            if isinstance(result, Exception):
                logger.error(f"Failed to evaluate chunk {chunks[i].id}: {result}")
                continue

            entity_eval, relation_eval = result
            entity_evaluations.append(entity_eval)
            relation_evaluations.append(relation_eval)

        return entity_evaluations, relation_evaluations

    async def _evaluate_entity_extraction(
@@ -222,9 +223,9 @@ async def _evaluate_entity_extraction(
            chunk_content=chunk.content,
            extracted_entities=json.dumps(extracted_entities, ensure_ascii=False, indent=2)
        )

        response = await self.llm_client.generate_answer(prompt)

        # Try to parse JSON response
        try:
            evaluation_result = json.loads(response)
@@ -246,14 +247,14 @@
                "precision_reasoning": "",
                "issues": ["LLM response parsing failed"]
            }

        # Validate and calculate overall_score if not provided
        if "overall_score" not in evaluation_result:
            accuracy = float(evaluation_result.get("accuracy", 0.0))
            completeness = float(evaluation_result.get("completeness", 0.0))
            precision = float(evaluation_result.get("precision", 0.0))
            evaluation_result["overall_score"] = 0.4 * accuracy + 0.4 * completeness + 0.2 * precision

        return {
            "chunk_id": chunk.id,
            "chunk_content": chunk.content[:200] if chunk.content else "",  # First 200 chars for debugging
@@ -285,9 +286,9 @@ async def _evaluate_relation_extraction(
            chunk_content=chunk.content,
            extracted_relations=json.dumps(extracted_relations, ensure_ascii=False, indent=2)
        )

        response = await self.llm_client.generate_answer(prompt)

        # Try to parse JSON response
        try:
            evaluation_result = json.loads(response)
@@ -309,14 +310,14 @@
                "precision_reasoning": "",
                "issues": ["LLM response parsing failed"]
            }

        # Validate and calculate overall_score if not provided
        if "overall_score" not in evaluation_result:
            accuracy = float(evaluation_result.get("accuracy", 0.0))
            completeness = float(evaluation_result.get("completeness", 0.0))
            precision = float(evaluation_result.get("precision", 0.0))
            evaluation_result["overall_score"] = 0.4 * accuracy + 0.4 * completeness + 0.2 * precision

        return {
            "chunk_id": chunk.id,
            "chunk_content": chunk.content[:200] if chunk.content else "",
@@ -358,26 +359,26 @@ def calculate_stats(scores: List[float]) -> Dict[str, float]:
            median = sorted_scores[n // 2] if n % 2 == 1 else (sorted_scores[n // 2 - 1] + sorted_scores[n // 2]) / 2
            variance = sum((x - mean) ** 2 for x in scores) / n
            std = variance ** 0.5

            return {
                "mean": mean,
                "median": median,
                "min": min(scores),
                "max": max(scores),
                "std": std
            }

        # Extract scores
        entity_overall_scores = [e.get("overall_score", 0.0) for e in entity_evaluations]
        entity_accuracy_scores = [e.get("accuracy", 0.0) for e in entity_evaluations]
        entity_completeness_scores = [e.get("completeness", 0.0) for e in entity_evaluations]
        entity_precision_scores = [e.get("precision", 0.0) for e in entity_evaluations]

        relation_overall_scores = [r.get("overall_score", 0.0) for r in relation_evaluations]
        relation_accuracy_scores = [r.get("accuracy", 0.0) for r in relation_evaluations]
        relation_completeness_scores = [r.get("completeness", 0.0) for r in relation_evaluations]
        relation_precision_scores = [r.get("precision", 0.0) for r in relation_evaluations]

        return {
            "entity_accuracy": {
                "overall_score": calculate_stats(entity_overall_scores),

0 commit comments
