Skip to content

Commit a4a9399

Browse files
committed
create task to ensure parallel execution
1 parent 6a16716 commit a4a9399

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/forge/actors/generator.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,16 @@ def split_keys(keys):
270270
return [keys[i::n_fetchers] for i in range(n_fetchers)]
271271

272272
futures = []
273-
for i, names in enumerate(split_keys(hf_param_names)):
274-
fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
275-
version=version, param_names=names
276-
)
277-
futures.append(fut)
273+
async with asyncio.TaskGroup() as tg:
274+
for i, names in enumerate(split_keys(hf_param_names)):
275+
276+
async def fetch_coro():
277+
return self.weight_fetchers.slice(procs=i).fetch.call_one(
278+
version=version, param_names=names
279+
)
280+
281+
fut = tg.create_task(fetch_coro())
282+
futures.append(fut)
278283

279284
sub_state_dicts = [await fut for fut in futures]
280285

0 commit comments

Comments
 (0)