@@ -49,24 +49,21 @@ def compute_logprobs(
 
 class SimpleGRPOLoss(nn.Module):
     """Simplified GRPO Loss for simplified single step updates
-    Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py.
+    Inspired by the Hugging Face TRL implementation:
+    https://github.com/huggingface/trl/blob/417915a3e4d3e3bc8d7b196594308b8eabf928be/trl/trainer/grpo_trainer.py#L1624.
     """
 
     def __init__(self, beta: float = 0.1):
         super().__init__()
         self.beta = beta
 
     def forward(self, logprobs, ref_logprobs, advantages, padding_mask):
-        per_token_kl = (
-            torch.exp(ref_logprobs.detach() - logprobs)
-            - (ref_logprobs.detach() - logprobs)
-            - 1
-        )
+        kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
         per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
-        per_token_loss = -(per_token_policy_loss - self.beta * per_token_kl)
+        per_token_loss = -(per_token_policy_loss - self.beta * kl)
         loss = (
-            (per_token_loss * padding_mask).sum(dim=1)
-            / (padding_mask.sum(dim=1) + 1e-8)
+            ((per_token_loss * padding_mask).sum(dim=1))
+            / (padding_mask.sum(dim=1).clamp(min=1.0))
         ).mean()
         return loss
 
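Annotation, not part of the commit: the new `kl` term is the exp(r) - r - 1 estimator with r = ref_logprobs - logprobs (non-negative by construction), and `.clamp(min=1.0)` keeps the per-sequence denominator at no less than one token. A minimal smoke test of the loss as rewritten above, assuming it lives in the same module; every tensor below is a made-up placeholder:

import torch

# Hypothetical shapes: batch of 2 sequences, 4 response tokens each.
loss_fn = SimpleGRPOLoss(beta=0.1)

logprobs = torch.randn(2, 4, requires_grad=True)      # current policy log-probs   [b x s]
ref_logprobs = torch.randn(2, 4)                      # reference policy log-probs [b x s]
advantages = torch.tensor([[0.7], [-0.3]])            # one scalar advantage per sequence [b x 1]
padding_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                             [0.0, 0.0, 0.0, 0.0]])   # second row is all padding

# clamp(min=1.0) guarantees a denominator of at least one token per row,
# so an all-padding row cannot inflate the per-sequence mean.
loss = loss_fn(logprobs, ref_logprobs, advantages, padding_mask)
loss.backward()
print(loss.item())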
@@ -211,21 +208,21 @@ def _qwen3_hf_to_vllm(self, saved_sd):
         return load_sd
 
     @endpoint
-    async def train_step(self, batch: list[Episode]):
-        batch = batch[self.dp_rank]
-        pad_id = batch[0].pad_id
+    async def train_step(self, batch: list[list[Episode]]):
+        microbatch = batch[self.dp_rank]
+        pad_id = microbatch[0].pad_id
 
         # prepare batch
-        request = [e.request_tensor for e in batch]
+        request = [e.request_tensor for e in microbatch]
         request = torch.stack(request).to(self.device)  # [b x s]
 
-        response = [e.response_tensor for e in batch]
+        response = [e.response_tensor for e in microbatch]
         response = torch.stack(response).to(self.device)  # [b x s]
 
-        ref_logprobs = [e.ref_logprobs for e in batch]
+        ref_logprobs = [e.ref_logprobs for e in microbatch]
         ref_logprobs = torch.stack(ref_logprobs).to(self.device).squeeze()  # [b x s]
 
-        advantages = [e.advantage for e in batch]
+        advantages = [e.advantage for e in microbatch]
         advantages = torch.tensor(advantages).to(self.device).unsqueeze(-1)  # [b x 1]
         del batch
 
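Annotation, not part of the commit: `train_step` now receives one microbatch per data-parallel rank and picks its own slice with `batch[self.dp_rank]`, so the caller is expected to pre-shard episodes. One way that sharding could look, sketched with assumed names (`shard_for_dp`, `episodes`, `dp_size` are not identifiers from this file):

def shard_for_dp(episodes, dp_size):
    """Round-robin a flat list of Episode objects into one microbatch per DP rank."""
    shards = [[] for _ in range(dp_size)]
    for i, episode in enumerate(episodes):
        shards[i % dp_size].append(episode)
    return shards  # list[list[Episode]], matching the new train_step signature

# Each trainer rank then reads only its own shard via batch[self.dp_rank].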
@@ -522,10 +519,10 @@ async def continuous_training():
     )
 
 
-@parse
-def recipe_main(cfg: DictConfig) -> None:
-    asyncio.run(main(cfg))
+if __name__ == "__main__":
 
+    @parse
+    def _main(cfg):
+        asyncio.run(main(cfg))
 
-if __name__ == "__main__":
-    recipe_main()
+    _main()  # @parse grabs the cfg from CLI
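Annotation, not part of the commit: both the `@parse`-decorated entrypoint and its invocation now sit inside the `__main__` guard, so importing this module (e.g. from tests) no longer touches the CLI. A toy stand-in for the pattern, where `fake_parse` is a placeholder and not the real `@parse`:

import sys

def fake_parse(fn):
    # placeholder for @parse: builds a cfg from the CLI when the wrapped fn runs
    def wrapper():
        cfg = {"argv": sys.argv[1:]}
        return fn(cfg)
    return wrapper

if __name__ == "__main__":

    @fake_parse
    def _main(cfg):
        print("running with", cfg)

    _main()  # defined and called under the guard, mirroring the diff above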