
Commit a6aedaf

refactor: refactor partition to accommodate ray data
1 parent d7d6c2a commit a6aedaf

File tree: 8 files changed (+203, -242 lines)


graphgen/bases/base_partitioner.py

Lines changed: 22 additions & 27 deletions
@@ -7,7 +7,7 @@
 
 class BasePartitioner(ABC):
     @abstractmethod
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         **kwargs: Any,
@@ -20,39 +20,34 @@ async def partition(
         """
 
     @staticmethod
-    async def community2batch(
-        communities: List[Community], g: BaseGraphStorage
-    ) -> list[
-        tuple[
-            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
-        ]
+    def community2batch(
+        comm: Community, g: BaseGraphStorage
+    ) -> tuple[
+        list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
     ]:
         """
         Convert communities to batches of nodes and edges.
-        :param communities
+        :param comm: Community
         :param g: Graph storage instance
         :return: List of batches, each batch is a tuple of (nodes, edges)
         """
-        batches = []
-        for comm in communities:
-            nodes = comm.nodes
-            edges = comm.edges
-            nodes_data = []
-            for node in nodes:
-                node_data = g.get_node(node)
-                if node_data:
-                    nodes_data.append((node, node_data))
-            edges_data = []
-            for u, v in edges:
-                edge_data = g.get_edge(u, v)
+        nodes = comm.nodes
+        edges = comm.edges
+        nodes_data = []
+        for node in nodes:
+            node_data = g.get_node(node)
+            if node_data:
+                nodes_data.append((node, node_data))
+        edges_data = []
+        for u, v in edges:
+            edge_data = g.get_edge(u, v)
+            if edge_data:
+                edges_data.append((u, v, edge_data))
+            else:
+                edge_data = g.get_edge(v, u)
                 if edge_data:
-                    edges_data.append((u, v, edge_data))
-                else:
-                    edge_data = g.get_edge(v, u)
-                    if edge_data:
-                        edges_data.append((v, u, edge_data))
-            batches.append((nodes_data, edges_data))
-        return batches
+                    edges_data.append((v, u, edge_data))
+        return nodes_data, edges_data
 
     @staticmethod
     def _build_adjacency_list(
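
Note (not part of the diff): a minimal consumption sketch of the refactored API. Import paths are assumed from the file layout above, and the iter_batches helper is hypothetical. It shows how the now-synchronous, generator-style partition() streams communities one at a time, and how the per-community community2batch() turns each one into a (nodes, edges) batch, e.g. before feeding a Ray Data pipeline.

# Hypothetical consumption sketch; not part of this commit.
# Import paths are assumed from the file layout shown in this diff.
from graphgen.bases import BasePartitioner
from graphgen.models.partitioner.bfs_partitioner import BFSPartitioner


def iter_batches(g, max_units_per_community: int = 10):
    """Yield one (nodes_data, edges_data) batch per community, lazily."""
    partitioner = BFSPartitioner()
    # partition() is now a plain generator: communities stream out one by one
    # instead of being collected into a List[Community] under an event loop.
    for community in partitioner.partition(
        g, max_units_per_community=max_units_per_community
    ):
        # community2batch() now maps a single Community to a single batch,
        # rather than a list of communities to a list of batches.
        yield BasePartitioner.community2batch(community, g)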

graphgen/models/partitioner/anchor_bfs_partitioner.py

Lines changed: 9 additions & 14 deletions
@@ -1,6 +1,6 @@
 import random
 from collections import deque
-from typing import Any, List, Literal, Set, Tuple
+from typing import Any, Iterable, List, Literal, Set, Tuple
 
 from graphgen.bases import BaseGraphStorage
 from graphgen.bases.datatypes import Community
@@ -30,42 +30,37 @@ def __init__(
         self.anchor_type = anchor_type
         self.anchor_ids = anchor_ids
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 1,
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes = g.get_all_nodes()  # List[tuple[id, meta]]
         edges = g.get_all_edges()  # List[tuple[u, v, meta]]
 
         adj, _ = self._build_adjacency_list(nodes, edges)
 
-        anchors: Set[str] = await self._pick_anchor_ids(nodes)
+        anchors: Set[str] = self._pick_anchor_ids(nodes)
         if not anchors:
-            return []  # if no anchors, return empty list
+            return  # if no anchors, return nothing
 
         used_n: set[str] = set()
         used_e: set[frozenset[str]] = set()
-        communities: List[Community] = []
 
         seeds = list(anchors)
         random.shuffle(seeds)
 
         for seed_node in seeds:
             if seed_node in used_n:
                 continue
-            comm_n, comm_e = await self._grow_community(
+            comm_n, comm_e = self._grow_community(
                 seed_node, adj, max_units_per_community, used_n, used_e
             )
             if comm_n or comm_e:
-                communities.append(
-                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
-                )
+                yield Community(id=seed_node, nodes=comm_n, edges=comm_e)
 
-        return communities
-
-    async def _pick_anchor_ids(
+    def _pick_anchor_ids(
         self,
         nodes: List[tuple[str, dict]],
     ) -> Set[str]:
@@ -80,7 +75,7 @@ async def _pick_anchor_ids(
         return anchor_ids
 
     @staticmethod
-    async def _grow_community(
+    def _grow_community(
         seed: str,
         adj: dict[str, List[str]],
         max_units: int,

graphgen/models/partitioner/bfs_partitioner.py

Lines changed: 4 additions & 9 deletions
@@ -1,6 +1,6 @@
 import random
 from collections import deque
-from typing import Any, List
+from typing import Any, Iterable, List
 
 from graphgen.bases import BaseGraphStorage, BasePartitioner
 from graphgen.bases.datatypes import Community
@@ -17,20 +17,19 @@ class BFSPartitioner(BasePartitioner):
     (A unit is a node or an edge.)
     """
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 1,
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes = g.get_all_nodes()
         edges = g.get_all_edges()
 
         adj, _ = self._build_adjacency_list(nodes, edges)
 
         used_n: set[str] = set()
         used_e: set[frozenset[str]] = set()
-        communities: List[Community] = []
 
         units = [(NODE_UNIT, n[0]) for n in nodes] + [
             (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges
@@ -74,8 +73,4 @@ async def partition(
                         queue.append((NODE_UNIT, n))
 
             if comm_n or comm_e:
-                communities.append(
-                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
-                )
-
-        return communities
+                yield Community(id=seed, nodes=comm_n, edges=comm_e)

graphgen/models/partitioner/dfs_partitioner.py

Lines changed: 4 additions & 8 deletions
@@ -1,4 +1,5 @@
 import random
+from collections.abc import Iterable
 from typing import Any, List
 
 from graphgen.bases import BaseGraphStorage, BasePartitioner
@@ -16,20 +17,19 @@ class DFSPartitioner(BasePartitioner):
     (In GraphGen, a unit is defined as a node or an edge.)
     """
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 1,
         **kwargs: Any,
-    ) -> List[Community]:
+    ) -> Iterable[Community]:
         nodes = g.get_all_nodes()
         edges = g.get_all_edges()
 
         adj, _ = self._build_adjacency_list(nodes, edges)
 
         used_n: set[str] = set()
         used_e: set[frozenset[str]] = set()
-        communities: List[Community] = []
 
         units = [(NODE_UNIT, n[0]) for n in nodes] + [
             (EDGE_UNIT, frozenset((u, v))) for u, v, _ in edges
@@ -71,8 +71,4 @@ async def partition(
                         stack.append((NODE_UNIT, n))
 
             if comm_n or comm_e:
-                communities.append(
-                    Community(id=len(communities), nodes=comm_n, edges=comm_e)
-                )
-
-        return communities
+                yield Community(id=seed, nodes=comm_n, edges=comm_e)

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 15 additions & 17 deletions
@@ -1,8 +1,8 @@
-import asyncio
 import random
+from collections import deque
 from typing import Any, Dict, List, Optional, Set, Tuple
 
-from tqdm.asyncio import tqdm as tqdm_async
+from tqdm import tqdm
 
 from graphgen.bases import BaseGraphStorage
 from graphgen.bases.datatypes import Community
@@ -51,7 +51,7 @@ def _sort_units(units: list, edge_sampling: str) -> list:
             raise ValueError(f"Invalid edge sampling: {edge_sampling}")
         return units
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_units_per_community: int = 10,
@@ -73,21 +73,19 @@ async def partition(
 
         used_n: Set[str] = set()
         used_e: Set[frozenset[str]] = set()
-        communities: List = []
+        communities: List[Community] = []
 
         all_units = self._sort_units(all_units, unit_sampling)
 
-        async def _grow_community(
-            seed_unit: Tuple[str, Any, dict]
-        ) -> Optional[Community]:
+        def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Optional[Community]:
            nonlocal used_n, used_e
 
            community_nodes: Dict[str, dict] = {}
            community_edges: Dict[frozenset[str], dict] = {}
-            queue: asyncio.Queue = asyncio.Queue()
+            queue = deque()
            token_sum = 0
 
-            async def _add_unit(u):
+            def _add_unit(u):
                nonlocal token_sum
                t, i, d = u
                if t == NODE_UNIT:  # node
@@ -103,19 +101,19 @@ async def _add_unit(u):
                    token_sum += d.get("length", 0)
                return True
 
-            await _add_unit(seed_unit)
-            await queue.put(seed_unit)
+            _add_unit(seed_unit)
+            queue.append(seed_unit)
 
            # BFS
-            while not queue.empty():
+            while queue:
                if (
                    len(community_nodes) + len(community_edges)
                    >= max_units_per_community
                    or token_sum >= max_tokens_per_community
                ):
                    break
 
-                cur_type, cur_id, _ = await queue.get()
+                cur_type, cur_id, _ = queue.popleft()
 
                neighbors: List[Tuple[str, Any, dict]] = []
                if cur_type == NODE_UNIT:
@@ -136,8 +134,8 @@ async def _add_unit(u):
                        or token_sum >= max_tokens_per_community
                    ):
                        break
-                    if await _add_unit(nb):
-                        await queue.put(nb)
+                    if _add_unit(nb):
+                        queue.append(nb)
 
            if len(community_nodes) + len(community_edges) < min_units_per_community:
                return None
@@ -148,13 +146,13 @@ async def _add_unit(u):
                edges=[(u, v) for (u, v), _ in community_edges.items()],
            )
 
-        async for unit in tqdm_async(all_units, desc="ECE partition"):
+        for unit in tqdm(all_units, desc="ECE partition"):
            utype, uid, _ = unit
            if (utype == NODE_UNIT and uid in used_n) or (
                utype == EDGE_UNIT and uid in used_e
            ):
                continue
-            comm = await _grow_community(unit)
+            comm = _grow_community(unit)
            if comm is not None:
                communities.append(comm)
 
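Note (not part of the diff): with the async layer removed in this file, a plain collections.deque gives the same FIFO traversal as asyncio.Queue without an event loop. The bfs_units helper and the toy adjacency map below are illustrative, not taken from the repository.

# Illustrative only: synchronous BFS over an adjacency map using collections.deque.
from collections import deque


def bfs_units(seed: str, adj: dict[str, list[str]], max_units: int) -> set[str]:
    """Collect up to max_units nodes reachable from seed, breadth-first."""
    visited = {seed}
    queue = deque([seed])  # popleft() gives FIFO order, no await needed
    while queue and len(visited) < max_units:
        cur = queue.popleft()
        for nb in adj.get(cur, []):
            if nb not in visited and len(visited) < max_units:
                visited.add(nb)
                queue.append(nb)
    return visited


# Example: bfs_units("a", {"a": ["b", "c"], "b": ["d"]}, 3) -> {"a", "b", "c"}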
graphgen/models/partitioner/leiden_partitioner.py

Lines changed: 5 additions & 9 deletions
@@ -13,7 +13,7 @@ class LeidenPartitioner(BasePartitioner):
     Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
     """
 
-    async def partition(
+    def partition(
         self,
         g: BaseGraphStorage,
         max_size: int = 20,
@@ -37,12 +37,10 @@ async def partition(
         nodes = g.get_all_nodes()  # List[Tuple[str, dict]]
         edges = g.get_all_edges()  # List[Tuple[str, str, dict]]
 
-        node2cid: Dict[str, int] = await self._run_leiden(
-            nodes, edges, use_lcc, random_seed
-        )
+        node2cid: Dict[str, int] = self._run_leiden(nodes, edges, use_lcc, random_seed)
 
         if max_size is not None and max_size > 0:
-            node2cid = await self._split_communities(node2cid, max_size)
+            node2cid = self._split_communities(node2cid, max_size)
 
         cid2nodes: Dict[int, List[str]] = defaultdict(list)
         for n, cid in node2cid.items():
@@ -58,7 +56,7 @@ async def partition(
         return communities
 
     @staticmethod
-    async def _run_leiden(
+    def _run_leiden(
         nodes: List[Tuple[str, dict]],
         edges: List[Tuple[str, str, dict]],
         use_lcc: bool = False,
@@ -92,9 +90,7 @@ async def _run_leiden(
         return node2cid
 
     @staticmethod
-    async def _split_communities(
-        node2cid: Dict[str, int], max_size: int
-    ) -> Dict[str, int]:
+    def _split_communities(node2cid: Dict[str, int], max_size: int) -> Dict[str, int]:
         """
         Split communities larger than max_size into smaller sub-communities.
         """
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from .partition_kg import partition_kg
+from .partition_service import PartitionService
