1- import asyncio
21import random
2+ from collections import deque
33from typing import Any , Dict , List , Optional , Set , Tuple
44
5- from tqdm . asyncio import tqdm as tqdm_async
5+ from tqdm import tqdm
66
77from graphgen .bases import BaseGraphStorage
88from graphgen .bases .datatypes import Community
@@ -51,7 +51,7 @@ def _sort_units(units: list, edge_sampling: str) -> list:
5151 raise ValueError (f"Invalid edge sampling: { edge_sampling } " )
5252 return units
5353
54- async def partition (
54+ def partition (
5555 self ,
5656 g : BaseGraphStorage ,
5757 max_units_per_community : int = 10 ,
@@ -73,21 +73,19 @@ async def partition(
7373
7474 used_n : Set [str ] = set ()
7575 used_e : Set [frozenset [str ]] = set ()
76- communities : List = []
76+ communities : List [ Community ] = []
7777
7878 all_units = self ._sort_units (all_units , unit_sampling )
7979
80- async def _grow_community (
81- seed_unit : Tuple [str , Any , dict ]
82- ) -> Optional [Community ]:
80+ def _grow_community (seed_unit : Tuple [str , Any , dict ]) -> Optional [Community ]:
8381 nonlocal used_n , used_e
8482
8583 community_nodes : Dict [str , dict ] = {}
8684 community_edges : Dict [frozenset [str ], dict ] = {}
87- queue : asyncio . Queue = asyncio . Queue ()
85+ queue = deque ()
8886 token_sum = 0
8987
90- async def _add_unit (u ):
88+ def _add_unit (u ):
9189 nonlocal token_sum
9290 t , i , d = u
9391 if t == NODE_UNIT : # node
@@ -103,19 +101,19 @@ async def _add_unit(u):
103101 token_sum += d .get ("length" , 0 )
104102 return True
105103
106- await _add_unit (seed_unit )
107- await queue .put (seed_unit )
104+ _add_unit (seed_unit )
105+ queue .append (seed_unit )
108106
109107 # BFS
110- while not queue . empty () :
108+ while queue :
111109 if (
112110 len (community_nodes ) + len (community_edges )
113111 >= max_units_per_community
114112 or token_sum >= max_tokens_per_community
115113 ):
116114 break
117115
118- cur_type , cur_id , _ = await queue .get ()
116+ cur_type , cur_id , _ = queue .popleft ()
119117
120118 neighbors : List [Tuple [str , Any , dict ]] = []
121119 if cur_type == NODE_UNIT :
@@ -136,8 +134,8 @@ async def _add_unit(u):
136134 or token_sum >= max_tokens_per_community
137135 ):
138136 break
139- if await _add_unit (nb ):
140- await queue .put (nb )
137+ if _add_unit (nb ):
138+ queue .append (nb )
141139
142140 if len (community_nodes ) + len (community_edges ) < min_units_per_community :
143141 return None
@@ -148,13 +146,13 @@ async def _add_unit(u):
148146 edges = [(u , v ) for (u , v ), _ in community_edges .items ()],
149147 )
150148
151- async for unit in tqdm_async (all_units , desc = "ECE partition" ):
149+ for unit in tqdm (all_units , desc = "ECE partition" ):
152150 utype , uid , _ = unit
153151 if (utype == NODE_UNIT and uid in used_n ) or (
154152 utype == EDGE_UNIT and uid in used_e
155153 ):
156154 continue
157- comm = await _grow_community (unit )
155+ comm = _grow_community (unit )
158156 if comm is not None :
159157 communities .append (comm )
160158
0 commit comments