Skip to content

Commit 3e3b0ff

Browse files
Merge pull request #59 from open-sciencelab/partitioner
refactor: Partitioner & Generator
2 parents 8eba85a + fed4baa commit 3e3b0ff

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+1775
-1478
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ max-public-methods=20
308308
max-returns=6
309309

310310
# Maximum number of statements in function / method body.
311-
max-statements=50
311+
max-statements=60
312312

313313
# Minimum number of public methods for a class (see R0903).
314314
min-public-methods=2

graphgen/bases/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from .base_generator import BaseGenerator
12
from .base_kg_builder import BaseKGBuilder
23
from .base_llm_client import BaseLLMClient
4+
from .base_partitioner import BasePartitioner
35
from .base_reader import BaseReader
46
from .base_splitter import BaseSplitter
57
from .base_storage import (

graphgen/bases/base_generator.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
3+
from typing import Any
4+
5+
from graphgen.bases.base_llm_client import BaseLLMClient
6+
7+
8+
@dataclass
9+
class BaseGenerator(ABC):
10+
"""
11+
Generate QAs based on given prompts.
12+
"""
13+
14+
llm_client: BaseLLMClient
15+
16+
@staticmethod
17+
@abstractmethod
18+
def build_prompt(
19+
batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
20+
) -> str:
21+
"""Build prompt for LLM based on the given batch"""
22+
23+
@staticmethod
24+
@abstractmethod
25+
def parse_response(response: str) -> Any:
26+
"""Parse the LLM response and return the generated QAs"""
27+
28+
async def generate(
29+
self,
30+
batch: tuple[
31+
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
32+
],
33+
) -> dict[str, Any]:
34+
"""
35+
Generate QAs based on a given batch.
36+
:param batch
37+
:return: QA pairs
38+
"""
39+
result = {}
40+
prompt = self.build_prompt(batch)
41+
response = await self.llm_client.generate_answer(prompt)
42+
qa_pairs = self.parse_response(response) # generate one or more QA pairs
43+
result.update(qa_pairs)
44+
return result
45+
46+
@staticmethod
47+
def format_generation_results(
48+
results: list[dict], output_data_format: str
49+
) -> list[dict[str, Any]]:
50+
if output_data_format == "Alpaca":
51+
results = [
52+
{
53+
"instruction": v["question"],
54+
"input": "",
55+
"output": v["answer"],
56+
}
57+
for item in results
58+
for k, v in item.items()
59+
]
60+
elif output_data_format == "Sharegpt":
61+
results = [
62+
{
63+
"conversations": [
64+
{"from": "human", "value": v["question"]},
65+
{"from": "gpt", "value": v["answer"]},
66+
]
67+
}
68+
for item in results
69+
for k, v in item.items()
70+
]
71+
elif output_data_format == "ChatML":
72+
results = [
73+
{
74+
"messages": [
75+
{"role": "user", "content": v["question"]},
76+
{"role": "assistant", "content": v["answer"]},
77+
]
78+
}
79+
for item in results
80+
for k, v in item.items()
81+
]
82+
else:
83+
raise ValueError(f"Unknown output data format: {output_data_format}")
84+
return results

graphgen/bases/base_partitioner.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
3+
from typing import Any, List
4+
5+
from graphgen.bases.base_storage import BaseGraphStorage
6+
from graphgen.bases.datatypes import Community
7+
8+
9+
@dataclass
class BasePartitioner(ABC):
    """Split a knowledge graph into communities and prepare generation batches."""

    @abstractmethod
    async def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> List[Community]:
        """
        Graph -> Communities
        :param g: Graph storage instance
        :param kwargs: Additional parameters for partitioning
        :return: List of communities
        """

    @staticmethod
    async def community2batch(
        communities: List[Community], g: BaseGraphStorage
    ) -> list[
        tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ]
    ]:
        """
        Convert communities to batches of nodes and edges.
        :param communities: communities to materialize
        :param g: Graph storage instance
        :return: List of batches, each batch is a tuple of (nodes, edges)
        """
        batches = []
        for community in communities:
            node_payloads = []
            for node_id in community.nodes:
                node_payload = await g.get_node(node_id)
                if node_payload:
                    node_payloads.append((node_id, node_payload))

            edge_payloads = []
            for u, v in community.edges:
                # The graph may store the edge in either orientation; keep the
                # first orientation that yields data (reverse tried only when
                # the forward lookup comes back empty).
                for src, dst in ((u, v), (v, u)):
                    edge_payload = await g.get_edge(src, dst)
                    if edge_payload:
                        edge_payloads.append((src, dst, edge_payload))
                        break

            batches.append((node_payloads, edge_payloads))
        return batches

    @staticmethod
    def _build_adjacency_list(
        nodes: List[tuple[str, dict]], edges: List[tuple[str, str, dict]]
    ) -> tuple[dict[str, List[str]], set[tuple[str, str]]]:
        """
        Build adjacency list and edge set from nodes and edges.
        :param nodes: (node_id, data) pairs
        :param edges: edge tuples whose first two items are the endpoints
        :return: adjacency list, edge set (both directions recorded)
        """
        neighbors: dict[str, List[str]] = {node_id: [] for node_id, _ in nodes}
        seen_pairs: set[tuple[str, str]] = set()
        for edge in edges:
            src, dst = edge[0], edge[1]
            neighbors[src].append(dst)
            neighbors[dst].append(src)
            seen_pairs.update(((src, dst), (dst, src)))
        return neighbors, seen_pairs

graphgen/bases/base_storage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ async def get_node(self, node_id: str) -> Union[dict, None]:
7878
async def update_node(self, node_id: str, node_data: dict[str, str]):
7979
raise NotImplementedError
8080

81-
async def get_all_nodes(self) -> Union[list[dict], None]:
81+
async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
8282
raise NotImplementedError
8383

8484
async def get_edge(
@@ -91,7 +91,7 @@ async def update_edge(
9191
):
9292
raise NotImplementedError
9393

94-
async def get_all_edges(self) -> Union[list[dict], None]:
94+
async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
9595
raise NotImplementedError
9696

9797
async def get_node_edges(

graphgen/bases/datatypes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,11 @@ class Token:
3030
@property
3131
def logprob(self) -> float:
3232
return math.log(self.prob)
33+
34+
35+
@dataclass
class Community:
    """A sub-graph produced by a graph partitioner: member nodes plus edges."""

    # Unique identifier of the community (int or str depending on the algorithm).
    id: Union[int, str]
    # Ids of the member nodes.
    nodes: List[str] = field(default_factory=list)
    # Edges inside the community; presumably (source, target) pairs — confirm
    # against the partitioner that fills this in.
    edges: List[tuple] = field(default_factory=list)
    # Free-form extra information attached by the partitioner.
    metadata: dict = field(default_factory=dict)

graphgen/configs/aggregated_config.yaml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,10 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1313
partition: # graph partition configuration
1414
method: ece # ece is a custom partition method based on comprehension loss
1515
method_params:
16-
bidirectional: true # whether to traverse the graph in both directions
17-
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
18-
expand_method: max_width # expand method, support: max_width, max_depth
19-
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20-
max_depth: 5 # maximum depth for graph traversal
21-
max_extra_edges: 20 # max edges per direction (if expand_method="max_width")
22-
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
23-
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
16+
max_units_per_community: 20 # max nodes and edges per community
17+
min_units_per_community: 5 # min nodes and edges per community
18+
max_tokens_per_community: 10240 # max tokens per community
19+
unit_sampling: max_loss # unit (node/edge) sampling strategy, support: random, max_loss, min_loss
2420
generate:
2521
mode: aggregated # atomic, aggregated, multi_hop, cot
2622
data_format: ChatML # Alpaca, Sharegpt, ChatML

graphgen/configs/atomic_config.yaml

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,9 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1111
quiz_samples: 2 # number of quiz samples to generate
1212
re_judge: false # whether to re-judge the existing quiz samples
1313
partition: # graph partition configuration
14-
method: ece # ece is a custom partition method based on comprehension loss
14+
method: dfs # partition method, support: dfs, bfs, ece, leiden
1515
method_params:
16-
bidirectional: true # whether to traverse the graph in both directions
17-
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
18-
expand_method: max_width # expand method, support: max_width, max_depth
19-
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20-
max_depth: 3 # maximum depth for graph traversal
21-
max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
22-
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
23-
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
16+
max_units_per_community: 1 # atomic partition, one node or edge per community
2417
generate:
2518
mode: atomic # atomic, aggregated, multi_hop, cot
2619
data_format: Alpaca # Alpaca, Sharegpt, ChatML

graphgen/configs/cot_config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ search: # web search configuration
99
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1010
enabled: false
1111
partition: # graph partition configuration
12-
method: leiden # leiden is a community detection algorithm
12+
method: leiden # leiden is a community detection algorithm used for graph partitioning
1313
method_params:
1414
max_size: 20 # Maximum size of communities
15-
use_lcc: false
16-
random_seed: 42
15+
use_lcc: false # whether to use the largest connected component
16+
random_seed: 42 # random seed for partitioning
1717
generate:
1818
mode: cot # atomic, aggregated, multi_hop, cot
1919
data_format: Sharegpt # Alpaca, Sharegpt, ChatML

graphgen/configs/multi_hop_config.yaml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,10 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1313
partition: # graph partition configuration
1414
method: ece # ece is a custom partition method based on comprehension loss
1515
method_params:
16-
bidirectional: true # whether to traverse the graph in both directions
17-
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
18-
expand_method: max_width # expand method, support: max_width, max_depth
19-
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20-
max_depth: 1 # maximum depth for graph traversal
21-
max_extra_edges: 2 # max edges per direction (if expand_method="max_width")
22-
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
23-
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
16+
max_units_per_community: 3 # max nodes and edges per community, for multi-hop, we recommend setting it to 3
17+
min_units_per_community: 3 # min nodes and edges per community, for multi-hop, we recommend setting it to 3
18+
max_tokens_per_community: 10240 # max tokens per community
19+
unit_sampling: random # unit (node/edge) sampling strategy, support: random, max_loss, min_loss
2420
generate:
2521
mode: multi_hop # strategy for generating multi-hop QA pairs
2622
data_format: ChatML # Alpaca, Sharegpt, ChatML

0 commit comments

Comments
 (0)