@@ -19,8 +19,25 @@ async def update_weights(self):
1919 pass
2020
2121
22- async def generate_rollout ():
23- pass
class Episode:
    """Container for the turns recorded during a single rollout.

    Every instance owns its own ``turns`` list, so turns appended to one
    episode are never visible through another episode.
    """

    def __init__(self):
        # Built per instance on purpose: a list declared at class level is
        # shared by all instances and would accumulate turns across rollouts.
        self.turns = []

    def add_turn(self, turn):
        """Append ``turn`` to this episode's ``turns`` list."""
        self.turns.append(turn)

    def add_transform_info(self, key, data):
        """Store ``data`` on this episode as the attribute named ``key``."""
        setattr(self, key, data)
31+
32+
class ComputeAdvantages(Actor):
    """Actor stub that is invoked with an episode; not implemented yet."""

    def __call__(self, episode):
        """Placeholder: ignores ``episode`` and returns ``None``."""
        return None
36+
37+
class RefModel(Actor):
    """Actor stub exposing a ``forward`` method; not implemented yet."""

    def forward(self, x):
        """Placeholder: ignores ``x`` and returns ``None``."""
        return None
2441
2542
2643async def main ():
@@ -46,11 +63,25 @@ async def main():
4663 batch_size = 4 ,
4764 max_policy_age = 1 ,
4865 )
66+ dataloader = await spawn_service (
67+ default_service_cfg ,
68+ ForgeDataset ,
69+ path = "gsm8k" ,
70+ )
4971
async def continuous_rollouts():
    """Build episodes in a loop and add each one to the replay buffer.

    Each iteration fetches one prompt from ``dataloader``, asks ``policy``
    for an action at its current version, records the ``(prompt, action)``
    pair as a turn, attaches the outputs of ``compute_advantages`` and
    ``ref_model`` to the episode, and adds the episode to ``replay_buffer``.

    Returns:
        None, as soon as the dataloader yields ``None`` instead of a prompt.
    """
    while True:
        prompt = await dataloader.__next__.call()
        if prompt is None:
            print("Dataloader is empty, exiting rollout creation")
            return
        version = await policy.get_current_version.choose()
        episode = Episode()
        with policy.session(version=version):
            action = await policy.generate.call(prompt)
            episode.add_turn((prompt, action))
        # Episode only defines add_turn/add_transform_info in this file, so
        # derived data is attached through add_transform_info rather than
        # through add_advantages/add_logprobs, which do not exist.
        advantages = await compute_advantages.__call__.call(episode)
        episode.add_transform_info("advantages", advantages)
        # NOTE(review): Episode defines no get_tokens() in this file either;
        # what should be fed to the reference model is not visible here, so
        # the call is left as written — confirm and implement on Episode.
        logprobs = await ref_model.forward.call(episode.get_tokens())
        episode.add_transform_info("logprobs", logprobs)
        await replay_buffer.add.call(episode)
5586
5687 rollout_task = asyncio .create_task (continuous_rollouts ())
0 commit comments