Commit 3ae2321

refactor: streamline evaluation functions for accuracy, consistency, and structure
1 parent a4d7993 commit 3ae2321

1 file changed: 59 additions, 295 deletions

@@ -1,307 +1,71 @@
-import argparse
-import json
-from pathlib import Path
+from typing import Any, Dict
 
 from dotenv import load_dotenv
 
 from graphgen.models import KGQualityEvaluator
-from graphgen.utils import CURRENT_LOGGER_VAR, logger, set_logger
+from graphgen.utils import logger
 
 # Load environment variables
 load_dotenv()
 
 
-def _run_evaluation(evaluator, args):
-    """Run the evaluation based on arguments."""
-    if args.accuracy_only:
-        logger.info("Running accuracy evaluation only...")
-        return {"accuracy": evaluator.evaluate_accuracy()}
-    if args.consistency_only:
-        logger.info("Running consistency evaluation only...")
-        return {"consistency": evaluator.evaluate_consistency()}
-    if args.structure_only:
-        logger.info("Running structural robustness evaluation only...")
-        return {"structure": evaluator.evaluate_structure()}
+def evaluate_accuracy(evaluator: KGQualityEvaluator) -> Dict[str, Any]:
+    """Evaluate accuracy of entity and relation extraction.
+
+    Args:
+        evaluator: KGQualityEvaluator instance
+
+    Returns:
+        Dictionary containing entity_accuracy and relation_accuracy metrics.
+    """
+    logger.info("Running accuracy evaluation...")
+    results = evaluator.evaluate_accuracy()
+    logger.info("Accuracy evaluation completed")
+    return results
+
+
+def evaluate_consistency(evaluator: KGQualityEvaluator) -> Dict[str, Any]:
+    """Evaluate consistency by detecting semantic conflicts.
+
+    Args:
+        evaluator: KGQualityEvaluator instance
+
+    Returns:
+        Dictionary containing consistency metrics including conflict_rate and conflicts.
+    """
+    logger.info("Running consistency evaluation...")
+    results = evaluator.evaluate_consistency()
+    logger.info("Consistency evaluation completed")
+    return results
+
+
+def evaluate_structure(evaluator: KGQualityEvaluator) -> Dict[str, Any]:
+    """Evaluate structural robustness of the graph.
+
+    Args:
+        evaluator: KGQualityEvaluator instance
+
+    Returns:
+        Dictionary containing structural metrics including noise_ratio, largest_cc_ratio, etc.
+    """
+    logger.info("Running structural robustness evaluation...")
+    results = evaluator.evaluate_structure()
+    logger.info("Structural robustness evaluation completed")
+    return results
+
+
+def evaluate_all(evaluator: KGQualityEvaluator) -> Dict[str, Any]:
+    """Run all evaluations (accuracy, consistency, structure).
+
+    Args:
+        evaluator: KGQualityEvaluator instance
+
+    Returns:
+        Dictionary containing all evaluation results with keys: accuracy, consistency, structure.
+    """
     logger.info("Running all evaluations...")
-    return evaluator.evaluate_all()
+    results = evaluator.evaluate_all()
+    logger.info("All evaluations completed")
+    return results
 
 
-def _print_accuracy_summary(acc):
-    """Print accuracy evaluation summary."""
-    if "error" not in acc:
-        print("\n[Accuracy]")
-        if "entity_accuracy" in acc:
-            e = acc["entity_accuracy"]
-            overall = e.get("overall_score", {})
-            accuracy = e.get("accuracy", {})
-            completeness = e.get("completeness", {})
-            precision = e.get("precision", {})
-
-            print(" Entity Extraction Quality:")
-            print(
-                f" Overall Score: {overall.get('mean', 0):.3f} (mean), "
-                f"{overall.get('median', 0):.3f} (median)"
-            )
-            print(
-                f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), "
-                f"{accuracy.get('median', 0):.3f} (median)"
-            )
-            print(
-                f" Completeness: {completeness.get('mean', 0):.3f} (mean), "
-                f"{completeness.get('median', 0):.3f} (median)"
-            )
-            print(
-                f" Precision: {precision.get('mean', 0):.3f} (mean), "
-                f"{precision.get('median', 0):.3f} (median)"
-            )
-            print(f" Total Chunks Evaluated: {e.get('total_chunks', 0)}")
-
-        if "relation_accuracy" in acc:
-            r = acc["relation_accuracy"]
-            overall = r.get("overall_score", {})
-            accuracy = r.get("accuracy", {})
-            completeness = r.get("completeness", {})
-            precision = r.get("precision", {})
-
-            print(" Relation Extraction Quality:")
-            print(
-                f" Overall Score: {overall.get('mean', 0):.3f} (mean), "
-                f"{overall.get('median', 0):.3f} (median)"
-            )
-            print(
-                f" Accuracy: {accuracy.get('mean', 0):.3f} (mean), "
-                f"{accuracy.get('median', 0):.3f} (median)"
-            )
-            print(
-                f" Completeness: {completeness.get('mean', 0):.3f} (mean), "
-                f"{completeness.get('median', 0):.3f} (median)"
-            )
-            print(
-                f" Precision: {precision.get('mean', 0):.3f} (mean), "
-                f"{precision.get('median', 0):.3f} (median)"
-            )
-            print(f" Total Chunks Evaluated: {r.get('total_chunks', 0)}")
-    else:
-        print(f"\n[Accuracy] Error: {acc['error']}")
-
-
-def _print_consistency_summary(cons):
-    """Print consistency evaluation summary."""
-    if "error" not in cons:
-        print("\n[Consistency]")
-        print(f" Conflict Rate: {cons.get('conflict_rate', 0):.3f}")
-        print(
-            f" Conflict Entities: {cons.get('conflict_entities_count', 0)} / "
-            f"{cons.get('total_entities', 0)}"
-        )
-        entities_checked = cons.get("entities_checked", 0)
-        if entities_checked > 0:
-            print(
-                f" Entities Checked: {entities_checked} (entities with multiple sources)"
-            )
-        conflicts = cons.get("conflicts", [])
-        if conflicts:
-            print(f" Total Conflicts Found: {len(conflicts)}")
-            # Show sample conflicts
-            sample_conflicts = conflicts[:3]
-            for conflict in sample_conflicts:
-                print(
-                    f" - {conflict.get('entity_id', 'N/A')}: {conflict.get('conflict_type', 'N/A')} "
-                    f"(severity: {conflict.get('conflict_severity', 0):.2f})"
-                )
-    else:
-        print(f"\n[Consistency] Error: {cons['error']}")
-
-
-def _print_structure_summary(struct):
-    """Print structural robustness evaluation summary."""
-    if "error" not in struct:
-        print("\n[Structural Robustness]")
-        print(f" Total Nodes: {struct.get('total_nodes', 0)}")
-        print(f" Total Edges: {struct.get('total_edges', 0)}")
-
-        thresholds = struct.get("thresholds", {})
-
-        # Noise Ratio
-        noise_check = thresholds.get("noise_ratio", {})
-        noise_threshold = noise_check.get("threshold", "N/A")
-        noise_pass = noise_check.get("pass", False)
-        print(
-            f" Noise Ratio: {struct.get('noise_ratio', 0):.3f} "
-            f"({'✓' if noise_pass else '✗'} < {noise_threshold})"
-        )
-
-        # Largest CC Ratio
-        lcc_check = thresholds.get("largest_cc_ratio", {})
-        lcc_threshold = lcc_check.get("threshold", "N/A")
-        lcc_pass = lcc_check.get("pass", False)
-        print(
-            f" Largest CC Ratio: {struct.get('largest_cc_ratio', 0):.3f} "
-            f"({'✓' if lcc_pass else '✗'} > {lcc_threshold})"
-        )
-
-        # Avg Degree
-        avg_degree_check = thresholds.get("avg_degree", {})
-        avg_degree_threshold = avg_degree_check.get("threshold", "N/A")
-        avg_degree_pass = avg_degree_check.get("pass", False)
-        # Format threshold for display (handle tuple case)
-        if isinstance(avg_degree_threshold, tuple):
-            threshold_str = f"{avg_degree_threshold[0]}-{avg_degree_threshold[1]}"
-        else:
-            threshold_str = str(avg_degree_threshold)
-        print(
-            f" Avg Degree: {struct.get('avg_degree', 0):.2f} "
-            f"({'✓' if avg_degree_pass else '✗'} {threshold_str})"
-        )
-
-        # Power Law R²
-        if struct.get("powerlaw_r2") is not None:
-            powerlaw_check = thresholds.get("powerlaw_r2", {})
-            powerlaw_threshold = powerlaw_check.get("threshold", "N/A")
-            powerlaw_pass = powerlaw_check.get("pass", False)
-            print(
-                f" Power Law R²: {struct.get('powerlaw_r2', 0):.3f} "
-                f"({'✓' if powerlaw_pass else '✗'} > {powerlaw_threshold})"
-            )
-    else:
-        print(f"\n[Structural Robustness] Error: {struct['error']}")
-
-
-def _print_summary(results):
-    """Print evaluation summary."""
-    print("\n" + "=" * 60)
-    print("KG Quality Evaluation Summary")
-    print("=" * 60)
-
-    if "accuracy" in results:
-        _print_accuracy_summary(results["accuracy"])
-    if "consistency" in results:
-        _print_consistency_summary(results["consistency"])
-    if "structure" in results:
-        _print_structure_summary(results["structure"])
-
-    print("\n" + "=" * 60)
-
-
-def main():
-    """Main function to run KG quality evaluation."""
-    parser = argparse.ArgumentParser(
-        description="Evaluate knowledge graph quality",
-        formatter_class=argparse.RawDescriptionHelpFormatter,
-        epilog="""
-Examples:
-  # Basic evaluation
-  python -m graphgen.operators.evaluate_kg.evaluate_kg --working_dir cache
-
-  # Custom output
-  python -m graphgen.operators.evaluate_kg.evaluate_kg \\
-      --working_dir cache \\
-      --output cache/kg_evaluation.json
-
-  # Specify backends
-  python -m graphgen.operators.evaluate_kg.evaluate_kg \\
-      --working_dir cache \\
-      --graph_backend networkx \\
-      --kv_backend json_kv
-        """,
-    )
-
-    parser.add_argument(
-        "--working_dir",
-        type=str,
-        default="cache",
-        help="Working directory containing graph and chunk storage (default: cache)",
-    )
-    parser.add_argument(
-        "--graph_backend",
-        type=str,
-        default="kuzu",
-        choices=["kuzu", "networkx"],
-        help="Graph storage backend (default: kuzu)",
-    )
-    parser.add_argument(
-        "--kv_backend",
-        type=str,
-        default="rocksdb",
-        choices=["rocksdb", "json_kv"],
-        help="KV storage backend (default: rocksdb)",
-    )
-    parser.add_argument(
-        "--max_concurrent",
-        type=int,
-        default=10,
-        help="Maximum concurrent LLM requests (default: 10)",
-    )
-    parser.add_argument(
-        "--output",
-        type=str,
-        default=None,
-        help="Output file path for evaluation results (default: working_dir/kg_evaluation.json)",
-    )
-    parser.add_argument(
-        "--accuracy_only",
-        action="store_true",
-        help="Only run accuracy evaluation",
-    )
-    parser.add_argument(
-        "--consistency_only",
-        action="store_true",
-        help="Only run consistency evaluation",
-    )
-    parser.add_argument(
-        "--structure_only",
-        action="store_true",
-        help="Only run structural robustness evaluation",
-    )
-
-    args = parser.parse_args()
-
-    # Set up logging
-    log_dir = Path(args.working_dir) / "logs"
-    log_dir.mkdir(parents=True, exist_ok=True)
-    default_logger = set_logger(str(log_dir / "evaluate_kg.log"), name="evaluate_kg")
-    CURRENT_LOGGER_VAR.set(default_logger)
-
-    # Determine output path
-    if args.output is None:
-        output_path = Path(args.working_dir) / "kg_evaluation.json"
-    else:
-        output_path = Path(args.output)
-
-    logger.info("Starting KG quality evaluation...")
-    logger.info(f"Working directory: {args.working_dir}")
-    logger.info(f"Graph backend: {args.graph_backend}")
-    logger.info(f"KV backend: {args.kv_backend}")
-
-    try:
-        evaluator = KGQualityEvaluator(
-            working_dir=args.working_dir,
-            graph_backend=args.graph_backend,
-            kv_backend=args.kv_backend,
-            max_concurrent=args.max_concurrent,
-        )
-    except Exception as e:
-        logger.error(f"Failed to initialize evaluator: {e}")
-        raise
-
-    # Run evaluation
-    try:
-        results = _run_evaluation(evaluator, args)
-
-        # Save results
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-        with open(output_path, "w", encoding="utf-8") as f:
-            json.dump(results, f, indent=2, ensure_ascii=False)
-
-        logger.info(f"Evaluation completed. Results saved to: {output_path}")
-
-        # Print summary
-        _print_summary(results)
-
-    except Exception as e:
-        logger.error(f"Evaluation failed: {e}", exc_info=True)
-        raise
-
-
-if __name__ == "__main__":
-    main()
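
With the argparse entry point gone, the module-level helpers are the surface that callers use directly: build a KGQualityEvaluator and pass it in. Below is a minimal driver sketch (not part of this commit); it assumes KGQualityEvaluator still accepts the constructor arguments the deleted main() passed, and it borrows the import path from the deleted --help examples, so treat both as assumptions rather than confirmed API.

import json
from pathlib import Path

from graphgen.models import KGQualityEvaluator
# Hypothetical import path, taken from the removed epilog's `python -m` examples.
from graphgen.operators.evaluate_kg.evaluate_kg import evaluate_all

# Mirrors the defaults of the removed CLI flags; assumes the KGQualityEvaluator
# constructor signature is unchanged by this commit.
evaluator = KGQualityEvaluator(
    working_dir="cache",
    graph_backend="kuzu",    # or "networkx"
    kv_backend="rocksdb",    # or "json_kv"
    max_concurrent=10,
)

# evaluate_accuracy / evaluate_consistency / evaluate_structure replace the old
# --accuracy_only / --consistency_only / --structure_only paths; evaluate_all
# replaces the default path.
results = evaluate_all(evaluator)

# Persist the results the same way the removed main() did.
output_path = Path("cache") / "kg_evaluation.json"
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)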
