Skip to content

Commit 81ad66d

Browse files
committed
Async enter
1 parent 6185edd commit 81ad66d

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

apps/grpo/main.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,13 @@ async def main():
68 68
ForgeDataset,
69 69
path="gsm8k",
70 70
)
71+
compute_advantages = await spawn_service(
72+
default_service_cfg,
73+
ComputeAdvantages,
74+
)
75+
ref_model = await spawn_service(default_service_cfg, RefModel)
71 76

77+
# ---- Core RL loops ---- #
72 78
async def continuous_rollouts():
73 79
while True:
74 80
prompt = await dataloader.__next__.call()
@@ -77,7 +83,7 @@ async def continuous_rollouts():
77 83
return
78 84
version = await policy.get_current_version.choose()
79 85
episode = Episode()
80-
with policy.session(version=version):
86+
async with policy.session(version=version):
81 87
action = await policy.generate.call(prompt)
82 88
episode.add_turn((prompt, action))
83 89
episode.add_advantages(await compute_advantages.__call__.call(episode))

0 commit comments

Comments
 (0)