Skip to content

Commit 6185edd

Browse files
committed
Add refmodel stub
1 parent 2a1546b commit 6185edd

File tree

1 file changed

+34
-3
lines changed

1 file changed

+34
-3
lines changed

apps/grpo/main.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,25 @@ async def update_weights(self):
1919
pass
2020

2121

22-
async def generate_rollout():
23-
pass
22+
class Episode:
    """Container for one rollout: an ordered list of turns plus ad-hoc attributes.

    Attributes:
        turns: Turns recorded via ``add_turn``, in insertion order.
    """

    def __init__(self):
        # Per-instance list. A class-level ``turns = []`` is a single list
        # shared by every Episode, so turns would leak between episodes.
        self.turns = []

    def add_turn(self, turn):
        """Append ``turn`` to this episode's ``turns`` list."""
        self.turns.append(turn)

    def add_transform_info(self, key, data):
        """Store ``data`` on this episode as the attribute named ``key``."""
        setattr(self, key, data)
31+
32+
33+
class ComputeAdvantages(Actor):
    """Placeholder actor; calling it currently does nothing.

    NOTE(review): judging by the name, this is meant to compute advantages
    for an episode -- confirm once the implementation lands.
    """

    def __call__(self, episode):
        """Accept ``episode`` and return ``None`` (not yet implemented)."""
        return None
36+
37+
38+
class RefModel(Actor):
    """Reference-model stub actor; ``forward`` is not implemented yet."""

    def forward(self, x):
        """Accept ``x`` and return ``None`` (placeholder body)."""
        return None
2441

2542

2643
async def main():
@@ -46,11 +63,25 @@ async def main():
4663
batch_size=4,
4764
max_policy_age=1,
4865
)
66+
dataloader = await spawn_service(
67+
default_service_cfg,
68+
ForgeDataset,
69+
path="gsm8k",
70+
)
4971

5072
# Rollout loop (shown here as a diff; original indentation is lost).
# Each iteration: fetch a prompt from the dataloader, ask the policy for its
# current version, generate an action for the prompt, record the
# (prompt, action) pair on a new Episode, attach advantages and reference
# logprobs, then add the episode to the replay buffer.
# Returns when the dataloader yields None; otherwise loops forever.
async def continuous_rollouts():
5173
while True:
74+
prompt = await dataloader.__next__.call()
75+
if prompt is None:
# NOTE(review): this f-string has no placeholders; a plain string literal
# would produce the same output.
76+
print(f"Dataloader is empty, exiting rollout creation")
77+
return
5278
version = await policy.get_current_version.choose()
# The "53-" marker means the next statement was removed by this commit; the
# "79+" lines that follow replace it.
53-
episode = await generate_rollout(version)
79+
episode = Episode()
80+
with policy.session(version=version):
81+
action = await policy.generate.call(prompt)
82+
episode.add_turn((prompt, action))
# NOTE(review): the Episode class in this diff defines only add_turn and
# add_transform_info; add_advantages, add_logprobs and get_tokens are not
# visible here. compute_advantages and ref_model are also not assigned in the
# visible hunks -- confirm they are defined elsewhere before running this.
83+
episode.add_advantages(await compute_advantages.__call__.call(episode))
84+
episode.add_logprobs(await ref_model.forward.call(episode.get_tokens()))
5485
await replay_buffer.add.call(episode)
5586

5687
rollout_task = asyncio.create_task(continuous_rollouts())

0 commit comments

Comments
 (0)