Skip to content

Commit 6d5df6a

Browse files
committed
set flag instead
1 parent 838a4a9 commit 6d5df6a

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

src/forge/actors/generator.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,10 @@ def split_keys(keys):
272272
futures = []
273273
for i, names in enumerate(split_keys(hf_param_names)):
274274

275-
async def fetch_coro():
276-
return self.weight_fetchers.slice(procs=i).fetch.call_one(
277-
version=version, param_names=names
278-
)
275+
fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
276+
version=version, param_names=names
277+
)
279278

280-
fut = asyncio.create_task(fetch_coro())
281279
futures.append(fut)
282280

283281
sub_state_dicts = [await fut for fut in futures]

src/forge/env.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,17 @@ def get_value(self) -> Any:
105105
description="Whether or not to use RDMA in TorchStore.",
106106
)
107107

108+
MONARCH_OLD_ASYNC_WORKAROUND = EnvVar(
109+
name="MONARCH_OLD_ASYNC_WORKAROUND",
110+
default=1,
111+
description=(
112+
"If enabled, monarch messages will be sent immediately even it's not"
113+
" awaited. This is needed for parallel fetching of weights, as using"
114+
" create_task creates race condition. This is a temporary workaround"
115+
" and will be removed once we have a better solution."
116+
),
117+
)
118+
108119

109120
def all_env_vars() -> list[EnvVar]:
110121
"""Retrieves all registered environment variable names."""

0 commit comments

Comments
 (0)