|
2 | 2 |
|
3 | 3 | from dotenv import load_dotenv |
4 | 4 |
|
5 | | -from graphgen.models import KGQualityEvaluator |
| 5 | +from graphgen.bases import BaseGraphStorage, BaseKVStorage, BaseLLMWrapper |
| 6 | +from graphgen.common import init_llm, init_storage |
| 7 | +from graphgen.models.evaluator.kg.accuracy_evaluator import AccuracyEvaluator |
| 8 | +from graphgen.models.evaluator.kg.consistency_evaluator import ConsistencyEvaluator |
| 9 | +from graphgen.models.evaluator.kg.structure_evaluator import StructureEvaluator |
6 | 10 | from graphgen.utils import logger |
7 | 11 |
|
8 | 12 | # Load environment variables |
9 | 13 | load_dotenv() |
10 | 14 |
|
11 | 15 |
|
12 | | -def evaluate_accuracy(evaluator: KGQualityEvaluator) -> Dict[str, Any]: |
13 | | - """Evaluate accuracy of entity and relation extraction. |
14 | | - |
15 | | - Args: |
16 | | - evaluator: KGQualityEvaluator instance |
| 16 | +class KGEvaluators: |
| 17 | + def __init__( |
| 18 | + self, |
| 19 | + working_dir: str = "cache", |
| 20 | + graph_backend: str = "kuzu", |
| 21 | + kv_backend: str = "rocksdb", |
| 22 | + **kwargs |
| 23 | + ): |
| 24 | + # Initialize storage |
| 25 | + self.graph_storage: BaseGraphStorage = init_storage( |
| 26 | + backend=graph_backend, working_dir=working_dir, namespace="graph" |
| 27 | + ) |
| 28 | + self.chunk_storage: BaseKVStorage = init_storage( |
| 29 | + backend=kv_backend, working_dir=working_dir, namespace="chunk" |
| 30 | + ) |
17 | 31 |
|
18 | | - Returns: |
19 | | - Dictionary containing entity_accuracy and relation_accuracy metrics. |
20 | | - """ |
| 32 | + # Initialize LLM client |
| 33 | + self.llm_client: BaseLLMWrapper = init_llm("synthesizer") |
| 34 | + |
| 35 | + # Initialize individual evaluators |
| 36 | + self.accuracy_evaluator = AccuracyEvaluator( |
| 37 | + graph_storage=self.graph_storage, |
| 38 | + chunk_storage=self.chunk_storage, |
| 39 | + llm_client=self.llm_client, |
| 40 | + ) |
| 41 | + |
| 42 | + self.consistency_evaluator = ConsistencyEvaluator( |
| 43 | + graph_storage=self.graph_storage, |
| 44 | + chunk_storage=self.chunk_storage, |
| 45 | + llm_client=self.llm_client, |
| 46 | + ) |
| 47 | + |
| 48 | + # Structure evaluator doesn't need chunk_storage or llm_client |
| 49 | + structure_params = kwargs.get("structure_params", {}) |
| 50 | + self.structure_evaluator = StructureEvaluator( |
| 51 | + graph_storage=self.graph_storage, |
| 52 | + **structure_params |
| 53 | + ) |
| 54 | + |
| 55 | + logger.info("KG evaluators initialized") |
| 56 | + |
| 57 | + |
| 58 | +def evaluate_accuracy(evaluators: KGEvaluators) -> Dict[str, Any]: |
21 | 59 | logger.info("Running accuracy evaluation...") |
22 | | - results = evaluator.evaluate_accuracy() |
| 60 | + results = evaluators.accuracy_evaluator.evaluate() |
23 | 61 | logger.info("Accuracy evaluation completed") |
24 | 62 | return results |
25 | 63 |
|
26 | 64 |
|
27 | | -def evaluate_consistency(evaluator: KGQualityEvaluator) -> Dict[str, Any]: |
28 | | - """Evaluate consistency by detecting semantic conflicts. |
29 | | - |
30 | | - Args: |
31 | | - evaluator: KGQualityEvaluator instance |
32 | | - |
33 | | - Returns: |
34 | | - Dictionary containing consistency metrics including conflict_rate and conflicts. |
35 | | - """ |
| 65 | +def evaluate_consistency(evaluators: KGEvaluators) -> Dict[str, Any]: |
36 | 66 | logger.info("Running consistency evaluation...") |
37 | | - results = evaluator.evaluate_consistency() |
| 67 | + results = evaluators.consistency_evaluator.evaluate() |
38 | 68 | logger.info("Consistency evaluation completed") |
39 | 69 | return results |
40 | 70 |
|
41 | 71 |
|
42 | | -def evaluate_structure(evaluator: KGQualityEvaluator) -> Dict[str, Any]: |
43 | | - """Evaluate structural robustness of the graph. |
44 | | - |
45 | | - Args: |
46 | | - evaluator: KGQualityEvaluator instance |
47 | | - |
48 | | - Returns: |
49 | | - Dictionary containing structural metrics including noise_ratio, largest_cc_ratio, etc. |
50 | | - """ |
| 72 | +def evaluate_structure(evaluators: KGEvaluators) -> Dict[str, Any]: |
51 | 73 | logger.info("Running structural robustness evaluation...") |
52 | | - results = evaluator.evaluate_structure() |
| 74 | + results = evaluators.structure_evaluator.evaluate() |
53 | 75 | logger.info("Structural robustness evaluation completed") |
54 | 76 | return results |
55 | 77 |
|
56 | 78 |
|
57 | | -def evaluate_all(evaluator: KGQualityEvaluator) -> Dict[str, Any]: |
58 | | - """Run all evaluations (accuracy, consistency, structure). |
59 | | - |
60 | | - Args: |
61 | | - evaluator: KGQualityEvaluator instance |
62 | | - |
63 | | - Returns: |
64 | | - Dictionary containing all evaluation results with keys: accuracy, consistency, structure. |
65 | | - """ |
| 79 | +def evaluate_all(evaluators: KGEvaluators) -> Dict[str, Any]: |
66 | 80 | logger.info("Running all evaluations...") |
67 | | - results = evaluator.evaluate_all() |
| 81 | + results = { |
| 82 | + "accuracy": evaluate_accuracy(evaluators), |
| 83 | + "consistency": evaluate_consistency(evaluators), |
| 84 | + "structure": evaluate_structure(evaluators), |
| 85 | + } |
68 | 86 | logger.info("All evaluations completed") |
69 | 87 | return results |
70 | | - |
71 | | - |
|
0 commit comments