Skip to content

Commit 7643b9f

Browse files
fix: return generator in ece_partitioner
1 parent aab7438 commit 7643b9f

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
from collections import deque
3-
from typing import Any, Dict, List, Optional, Set, Tuple
3+
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
44

55
from tqdm import tqdm
66

@@ -59,7 +59,7 @@ def partition(
5959
max_tokens_per_community: int = 10240,
6060
unit_sampling: str = "random",
6161
**kwargs: Any,
62-
) -> List[Community]:
62+
) -> Iterable[Community]:
6363
nodes: List[Tuple[str, dict]] = g.get_all_nodes()
6464
edges: List[Tuple[str, str, dict]] = g.get_all_edges()
6565

@@ -73,7 +73,6 @@ def partition(
7373

7474
used_n: Set[str] = set()
7575
used_e: Set[frozenset[str]] = set()
76-
communities: List[Community] = []
7776

7877
all_units = self._sort_units(all_units, unit_sampling)
7978

@@ -141,7 +140,7 @@ def _add_unit(u):
141140
return None
142141

143142
return Community(
144-
id=len(communities),
143+
id=seed_unit[1],
145144
nodes=list(community_nodes.keys()),
146145
edges=[(u, v) for (u, v), _ in community_edges.items()],
147146
)
@@ -153,7 +152,5 @@ def _add_unit(u):
153152
):
154153
continue
155154
comm = _grow_community(unit)
156-
if comm is not None:
157-
communities.append(comm)
158-
159-
return communities
155+
if comm:
156+
yield comm

0 commit comments

Comments
 (0)