Skip to content

Commit f6cce9b

Browse files
feat: add kg_structure evaluation
1 parent 084cb08 commit f6cce9b

File tree

15 files changed

+312
-567
lines changed

15 files changed

+312
-567
lines changed

examples/evaluate/evaluate_kg/kg_evaluation_config.yaml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ nodes:
2727
op_name: build_kg
2828
type: map_batch
2929
dependencies:
30-
- chunk_documents
30+
- chunk
3131
execution_params:
3232
replicas: 1
3333
batch_size: 128
@@ -40,6 +40,6 @@ nodes:
4040
- build_kg
4141
params:
4242
metrics:
43-
- kg_accuracy
44-
- kg_consistency
4543
- kg_structure
44+
# - kg_accuracy
45+
# - kg_consistency

examples/evaluate/evaluate_qa/qa_evaluation_config.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ nodes:
9292
metrics:
9393
- qa_length
9494
- qa_mtld
95-
# - qa_reward_score
96-
# - qa_uni_score
95+
- qa_reward_score
96+
- qa_uni_score
9797
mtld_params:
9898
threshold: 0.7

graphgen/bases/base_storage.py

Lines changed: 43 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
1+
from abc import ABC, abstractmethod
12
from dataclasses import dataclass
2-
from typing import Generic, TypeVar, Union
3+
from typing import Dict, Generic, List, Set, TypeVar, Union
34

45
T = TypeVar("T")
56

@@ -45,52 +46,90 @@ def reload(self):
4546
raise NotImplementedError
4647

4748

48-
class BaseGraphStorage(StorageNameSpace):
49+
class BaseGraphStorage(StorageNameSpace, ABC):
50+
@abstractmethod
51+
def is_directed(self) -> bool:
52+
pass
53+
54+
@abstractmethod
4955
def has_node(self, node_id: str) -> bool:
5056
raise NotImplementedError
5157

58+
@abstractmethod
5259
def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
5360
raise NotImplementedError
5461

62+
@abstractmethod
5563
def node_degree(self, node_id: str) -> int:
5664
raise NotImplementedError
5765

58-
def edge_degree(self, src_id: str, tgt_id: str) -> int:
59-
raise NotImplementedError
66+
@abstractmethod
67+
def get_all_node_degrees(self) -> Dict[str, int]:
68+
pass
6069

70+
def get_isolated_nodes(self) -> List[str]:
71+
return [
72+
node_id
73+
for node_id, degree in self.get_all_node_degrees().items()
74+
if degree == 0
75+
]
76+
77+
@abstractmethod
6178
def get_node(self, node_id: str) -> Union[dict, None]:
6279
raise NotImplementedError
6380

81+
@abstractmethod
6482
def update_node(self, node_id: str, node_data: dict[str, str]):
6583
raise NotImplementedError
6684

85+
@abstractmethod
6786
def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
6887
raise NotImplementedError
6988

89+
@abstractmethod
90+
def get_node_count(self) -> int:
91+
pass
92+
93+
@abstractmethod
7094
def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
7195
raise NotImplementedError
7296

97+
@abstractmethod
7398
def update_edge(
7499
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
75100
):
76101
raise NotImplementedError
77102

103+
@abstractmethod
78104
def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
79105
raise NotImplementedError
80106

107+
@abstractmethod
108+
def get_edge_count(self) -> int:
109+
pass
110+
111+
@abstractmethod
81112
def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
82113
raise NotImplementedError
83114

115+
@abstractmethod
84116
def upsert_node(self, node_id: str, node_data: dict[str, str]):
85117
raise NotImplementedError
86118

119+
@abstractmethod
87120
def upsert_edge(
88121
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
89122
):
90123
raise NotImplementedError
91124

125+
@abstractmethod
92126
def delete_node(self, node_id: str):
93127
raise NotImplementedError
94128

129+
@abstractmethod
95130
def reload(self):
96131
raise NotImplementedError
132+
133+
@abstractmethod
134+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
135+
raise NotImplementedError

graphgen/common/init_storage.py

Lines changed: 39 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, Union
1+
from typing import Any, Dict, List, Set, Union
22

33
import ray
44

@@ -68,6 +68,21 @@ def __init__(self, backend: str, working_dir: str, namespace: str):
6868
def index_done_callback(self):
6969
return self.graph.index_done_callback()
7070

71+
def is_directed(self) -> bool:
72+
return self.graph.is_directed()
73+
74+
def get_all_node_degrees(self) -> Dict[str, int]:
75+
return self.graph.get_all_node_degrees()
76+
77+
def get_node_count(self) -> int:
78+
return self.graph.get_node_count()
79+
80+
def get_edge_count(self) -> int:
81+
return self.graph.get_edge_count()
82+
83+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
84+
return self.graph.get_connected_components(undirected)
85+
7186
def has_node(self, node_id: str) -> bool:
7287
return self.graph.has_node(node_id)
7388

@@ -165,6 +180,21 @@ def __init__(self, actor_handle: ray.actor.ActorHandle):
165180
def index_done_callback(self):
166181
return ray.get(self.actor.index_done_callback.remote())
167182

183+
def is_directed(self) -> bool:
184+
return ray.get(self.actor.is_directed.remote())
185+
186+
def get_all_node_degrees(self) -> Dict[str, int]:
187+
return ray.get(self.actor.get_all_node_degrees.remote())
188+
189+
def get_node_count(self) -> int:
190+
return ray.get(self.actor.get_node_count.remote())
191+
192+
def get_edge_count(self) -> int:
193+
return ray.get(self.actor.get_edge_count.remote())
194+
195+
def get_connected_components(self, undirected: bool = True) -> List[Set[str]]:
196+
return ray.get(self.actor.get_connected_components.remote(undirected))
197+
168198
def has_node(self, node_id: str) -> bool:
169199
return ray.get(self.actor.has_node.remote(node_id))
170200

@@ -239,10 +269,14 @@ def create_storage(backend: str, working_dir: str, namespace: str):
239269
try:
240270
actor_handle = ray.get_actor(actor_name)
241271
except ValueError:
242-
actor_handle = ray.remote(actor_class).options(
243-
name=actor_name,
244-
get_if_exists=True,
245-
).remote(backend, working_dir, namespace)
272+
actor_handle = (
273+
ray.remote(actor_class)
274+
.options(
275+
name=actor_name,
276+
get_if_exists=True,
277+
)
278+
.remote(backend, working_dir, namespace)
279+
)
246280
ray.get(actor_handle.ready.remote())
247281
return proxy_class(actor_handle)
248282

graphgen/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
from .evaluator import (
2-
KGQualityEvaluator,
2+
AccuracyEvaluator,
3+
ConsistencyEvaluator,
34
LengthEvaluator,
45
MTLDEvaluator,
56
RewardEvaluator,
7+
StructureEvaluator,
68
UniEvaluator,
79
)
810
from .generator import (
graphgen/models/evaluator/__init__.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,2 @@
1+
from .kg import AccuracyEvaluator, ConsistencyEvaluator, StructureEvaluator
12
from .qa import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
2-
from .kg import (
3-
AccuracyEvaluator,
4-
ConsistencyEvaluator,
5-
KGQualityEvaluator,
6-
StructureEvaluator,
7-
)

graphgen/models/evaluator/kg/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,12 +9,10 @@
99

1010
from .accuracy_evaluator import AccuracyEvaluator
1111
from .consistency_evaluator import ConsistencyEvaluator
12-
from .kg_quality_evaluator import KGQualityEvaluator
1312
from .structure_evaluator import StructureEvaluator
1413

1514
__all__ = [
1615
"AccuracyEvaluator",
1716
"ConsistencyEvaluator",
18-
"KGQualityEvaluator",
1917
"StructureEvaluator",
2018
]

graphgen/models/evaluator/kg/kg_quality_evaluator.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

0 commit comments

Comments
 (0)