     get_dcp_whole_state_dict_key,
     get_param_prefix,
 )
-from forge.actors.policy import Policy
+from forge.actors.generator import Generator
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.trainer import RLTrainer
 from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data.rewards import MathReward, ThinkingReward
+from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer

 @dataclass
 class Episode:
-    # TODO: add adtional layer for multi-turn
     episode_id: str
-    request: str
-    policy_version: int
     pad_id: int
     request_len: int
     response_len: int
     target: Any | None = None
-    # processed data
-    response: str | None = None
-    request_tokens: list[int] | None = None
-    response_tokens: list[int] | None = None
+    # Processed data
+    completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
     advantage: float | None = None

     @property
-    def request_tensor(self):
-        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
+    def policy_version(self) -> int | None:
+        return self.completion.generator_version
+
+    @property
+    def request_tensor(self) -> torch.Tensor:
+        request_tokens: torch.Tensor = self.completion.prompt_ids
+        tensor = torch.tensor(request_tokens, dtype=torch.long)
         if tensor.shape[0] < self.request_len:  # left pad
             diff = self.request_len - tensor.shape[0]
             tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
         return tensor

     @property
-    def response_tensor(self):
-        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
+    def response_tensor(self) -> torch.Tensor:
+        response_tokens: torch.Tensor = self.completion.token_ids
+        tensor = torch.tensor(response_tokens, dtype=torch.long)
         if tensor.shape[0] < self.response_len:  # right pad
             diff = self.response_len - tensor.shape[0]
             tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor


-@dataclass
-class Group:
-    group_id: str
-    episodes: list[Episode]
-
-    @classmethod
-    def new_group(
-        cls,
-        group_id: int,
-        group_size: int,
-        request: str,
-        policy_version: int,
-        pad_id: int,
-        request_len: int,
-        response_len: int,
-        target: Any = None,
-    ):
-        episodes = []
-        for _ in range(group_size):
-            episodes.append(
-                Episode(
-                    episode_id=str(uuid.uuid4()),
-                    request=request,
-                    policy_version=policy_version,
-                    pad_id=pad_id,
-                    request_len=request_len,
-                    response_len=response_len,
-                    target=target,
-                )
-            )
-        return cls(str(group_id), episodes)
+# Represents the group (G) of episodes in GRPO
+Group = list[Episode]
+
+# Represents the Policy Model to collect data from
+Policy = Generator
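
For readers skimming the new Episode shape, here is a minimal sketch (not part of this change) of how an episode wraps a completion-like object. It relies on the Episode and Group definitions above; SimpleNamespace stands in for forge.data_models.completion.Completion and assumes only the fields this file reads (prompt_ids, token_ids, generator_version).

import uuid
from types import SimpleNamespace

# Hypothetical completion; plain lists are fine since the Episode properties
# pass them through torch.tensor().
completion = SimpleNamespace(
    prompt_ids=[5, 6, 7],
    token_ids=[8, 9],
    generator_version=3,
)
episode = Episode(
    episode_id=str(uuid.uuid4()),
    pad_id=0,
    request_len=8,
    response_len=4,
    completion=completion,
)
assert episode.policy_version == 3
assert episode.request_tensor.shape[0] == 8   # left-padded to request_len
assert episode.response_tensor.shape[0] == 4  # right-padded to response_len
group: Group = [episode]                      # a Group is just a list of Episodes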


-def collate(batches: list[list[Episode]]):
+def collate(
+    batches: list[Group],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+    """
+    Collates a list of batches into a single batch of inputs and targets.
+    Each batch is a Group of episodes; the output is a list of input dicts
+    and a matching list of target dicts of tensors.
+    """
     inputs = []
     targets = []
     for batch in batches:
@@ -221,7 +203,7 @@ class ComputeAdvantages(ForgeActor):
     @endpoint
     async def compute(self, group: Group) -> list[float]:
         # TODO: add batch processing
-        rewards = torch.tensor([[e.reward for e in group.episodes]])
+        rewards = torch.tensor([[e.reward for e in group]])
         mean = rewards.mean(1, keepdim=True)
         std = rewards.std(1, keepdim=True)
         advantages = (rewards - mean) / (std + 1e-4)
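
As a quick check of the normalization above, here is a small worked example (a sketch, not part of the change) with a hypothetical Group of four scalar rewards:

import torch

rewards = torch.tensor([[1.0, 0.0, 1.0, 0.0]])   # hypothetical rewards for one Group
mean = rewards.mean(1, keepdim=True)             # 0.5
std = rewards.std(1, keepdim=True)               # ~0.5774 (unbiased std over the group)
advantages = (rewards - mean) / (std + 1e-4)     # ~[[0.866, -0.866, 0.866, -0.866]]

Episodes scoring above the group mean get positive advantages and those below get negative ones, so each Group acts as its own baseline.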
@@ -383,44 +365,32 @@ async def continuous_rollouts():
             t.step("data_loading")

             prompt, target = sample["request"], sample["target"]
-            responses = await policy.generate.route(prompt)
-            # TODO: this shall be part of the responses metadata instead of a separate call
-            version = await policy.get_version.route()
-
+            responses: list[Completion] = await policy.generate.route(prompt)
             t.step("policy_generation")

-            assert (
-                len(responses) > 0
-            ), "Sanity check: Responses should NEVER return empty"
-            assert (
-                version := responses[0].generator_version
-            ) is not None, "Response must indicate a version"
-            group = Group.new_group(
-                group_id=rollout_count,
-                group_size=group_size,
-                request=prompt,
-                policy_version=version,
-                pad_id=pad_id,
-                request_len=max_req_tokens,
-                response_len=max_res_tokens,
-                target=target,
-            )
-
+            # Construct episodes and calculate rewards
+            episodes = []
             input_ids = torch.ones(
                 (group_size, max_req_tokens + max_res_tokens),
                 dtype=torch.long,
-                device="cuda",
             )
-            # Populate episode info and calculate rewards
-            for i, (episode, response) in enumerate(zip(group.episodes, responses)):
-                episode.request_tokens = response.prompt_ids
-                episode.response_tokens = response.token_ids
-                episode.response = response.text
-                input_ids[i, :max_req_tokens] = episode.request_tensor
-                input_ids[i, max_req_tokens:] = episode.response_tensor
+            for i, response in enumerate(responses):
+                episode = Episode(
+                    episode_id=str(uuid.uuid4()),
+                    pad_id=pad_id,
+                    request_len=max_req_tokens,
+                    response_len=max_res_tokens,
+                    target=target,
+                    completion=response,
+                )
                 episode.reward = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
+                episodes.append(episode)
+
+                # Build input_ids for reference logprobs
+                input_ids[i, :max_req_tokens] = episode.request_tensor
+                input_ids[i, max_req_tokens:] = episode.response_tensor

             t.step("reward_evaluation")

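
The loop above packs each episode into one fixed-width row of input_ids: the request is left-padded to max_req_tokens and the response right-padded to max_res_tokens. A minimal sketch of that layout (not part of the change, using hypothetical token ids):

import torch
import torch.nn.functional as F

pad_id, max_req_tokens, max_res_tokens = 0, 4, 3
prompt_ids, token_ids = [5, 6], [7, 8]  # hypothetical prompt/response tokens

request = F.pad(torch.tensor(prompt_ids), (max_req_tokens - len(prompt_ids), 0), value=pad_id)
response = F.pad(torch.tensor(token_ids), (0, max_res_tokens - len(token_ids)), value=pad_id)
row = torch.cat([request, response])  # tensor([0, 0, 5, 6, 7, 8, 0])

This mirrors Episode.request_tensor and Episode.response_tensor, so every row passed to the reference model has width max_req_tokens + max_res_tokens.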
@@ -429,14 +399,13 @@ async def continuous_rollouts():
             )
             t.step("reference_model_calculate_logprobs")

-            for i, episode in enumerate(group.episodes):
+            for i, episode in enumerate(episodes):
                 episode.ref_logprobs = ref_logprobs[i]
             del ref_logprobs, input_ids
-            t.step("compute_logprobs")

             # Calculate advantages and add to replay buffer
-            advantages = await compute_advantages.compute.call_one(group)
-            for episode, advantage in zip(group.episodes, advantages):
+            advantages = await compute_advantages.compute.call_one(episodes)
+            for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)

@@ -524,22 +493,6 @@ async def continuous_training():

         training_task.cancel()

-        # give mlogger time to shutdown backends, otherwise they can stay running.
-        # TODO (felipemello) find more elegant solution
-        await mlogger.shutdown.call_one()
-        await asyncio.sleep(2)
-
-        await asyncio.gather(
-            DatasetActor.shutdown(dataloader),
-            policy.shutdown(),
-            RLTrainer.shutdown(trainer),
-            ReplayBuffer.shutdown(replay_buffer),
-            ComputeAdvantages.shutdown(compute_advantages),
-            ref_model.shutdown(),
-            reward_actor.shutdown(),
-        )
-        # TODO - add a global shutdown that implicitly shuts down all services
-        # and remote allocations
         await shutdown()

