@@ -120,6 +120,7 @@ def collate(
     return inputs, targets


+# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
 def simple_grpo_loss(
     logits: torch.Tensor,
     response: torch.Tensor,
@@ -128,12 +129,7 @@ def simple_grpo_loss(
     padding_mask: torch.Tensor,
     beta: float = 0.1,
 ) -> torch.Tensor:
-    """
-    Example GRPO Loss Function for RLTrainer
-    """
     logprobs: torch.Tensor = compute_logprobs(logits, response)
-
-    # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
     kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
     per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
     per_token_loss = -(per_token_policy_loss - beta * kl)
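For review context: the `kl` line is the non-negative k3 KL estimator, exp(r) - r - 1 with r = ref_logprobs - logprobs, and the exp(logprobs - logprobs.detach()) factor equals 1.0 in value while still carrying the policy gradient. Below is a minimal self-contained sketch of the computation on dummy tensors; the masked-mean reduction at the end is an assumption, since this hunk cuts off before the function's final lines.

```python
import torch

torch.manual_seed(0)
B, T = 2, 5                           # batch size, response length (illustrative)
logprobs = -torch.rand(B, T)          # stand-in for compute_logprobs(logits, response)
ref_logprobs = -torch.rand(B, T)      # frozen reference-policy log-probs
advantages = torch.randn(B, 1)        # one advantage per sequence, broadcast over tokens
padding_mask = torch.ones(B, T)       # 1.0 = real response token, 0.0 = padding
beta = 0.1

# k3 KL estimator: exp(r) - r - 1, r = ref_logprobs - logprobs (always >= 0).
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1

# exp(lp - lp.detach()) is exactly 1.0 in the forward pass but keeps
# d(loss)/d(logprobs), recovering the vanilla policy-gradient estimator.
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - beta * kl)

# Assumed reduction (the hunk cuts off here): masked mean over tokens, mean over batch.
loss = (
    (per_token_loss * padding_mask).sum(dim=1) / padding_mask.sum(dim=1).clamp(min=1.0)
).mean()
print(loss)  # scalar tensor
```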
@@ -146,7 +142,6 @@ def simple_grpo_loss(


 @dataclass
 class RewardActor(ForgeActor):
-    """Reward actor that uses a list of scoring functions."""

     reward_functions: list[Callable]

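Since the docstring is removed here, a note for reviewers: each entry in `reward_functions` is a scoring callable applied per response. A hypothetical example of a compatible scorer; the `(prompt, response, target) -> float` contract is inferred from `evaluate_response` below.

```python
# Hypothetical scorer matching the inferred (prompt, response, target) -> float contract.
def exact_match_reward(prompt: str, response: str, target: str) -> float:
    """Return 1.0 if the stripped response equals the target, else 0.0."""
    return 1.0 if response.strip() == target.strip() else 0.0

# reward_actor = RewardActor(reward_functions=[exact_match_reward])
```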
@@ -178,14 +173,12 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
                 Reduce.STD,
             )

-            # avg total reward
             record_metric(
                 "reward/evaluate_response/avg_total_reward",
                 reward,
                 Reduce.MEAN,
             )

-            # count fn calls
             record_metric(
                 f"reward/evaluate_response/count_{reward_fn_name}_calls",
                 1,
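The `Reduce` mode chosen per key determines how repeated `record_metric` calls aggregate at flush time. A plain-Python sketch of the assumed semantics; the real aggregation lives in the metrics backend.

```python
# Assumed aggregation semantics for one metric key across repeated calls.
values = [0.0, 1.0, 1.0]                  # e.g. rewards recorded under one key
mean = sum(values) / len(values)          # Reduce.MEAN -> avg_total_reward
total = sum(values)                       # Reduce.SUM  -> count_*_calls (each call records 1)
std = (sum((v - mean) ** 2 for v in values) / len(values)) ** 0.5  # Reduce.STD
print(mean, total, std)
```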
@@ -198,8 +191,6 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:

 @dataclass
 class ComputeAdvantages(ForgeActor):
-    """Compute advantages for GRPO using reward signals."""
-
     @endpoint
     async def compute(self, group: Group) -> list[float]:
         # TODO: add batch processing
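With the docstring gone, a reminder of what `compute` does in GRPO: advantages are group-relative, i.e. each episode's reward is normalized against the other rollouts of the same prompt. A minimal sketch, assuming `group` yields one scalar reward per episode; the epsilon is illustrative.

```python
import torch

# Hypothetical per-episode rewards gathered from one group of rollouts.
rewards = torch.tensor([0.0, 1.0, 1.0, 0.0])

# Group-relative advantage: center on the group mean, scale by the group std.
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
print(advantages.tolist())  # list[float], matching the endpoint's return type
```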
@@ -255,7 +246,6 @@ async def sample(self) -> dict[str, str] | None:
         try:
             sample = next(self._iterator)

-            # Record dataset metrics
             record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
             record_metric(
                 "dataset/sample/avg_sample_len",
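For context on the surrounding `try` (its handler falls outside this hunk): the `dict[str, str] | None` return annotation suggests dataset exhaustion is signaled by returning `None`. A minimal sketch of that assumed pattern:

```python
# Assumed iterator-exhaustion pattern behind sample().
def sample_once(iterator):
    try:
        return next(iterator)
    except StopIteration:
        return None  # tells the rollout loop the dataset is exhausted

print(sample_once(iter([])))  # None
```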
@@ -406,13 +396,11 @@ async def continuous_rollouts():
                 episode.ref_logprobs = ref_logprobs[i]
             del ref_logprobs, input_ids

-            # Calculate advantages and add to replay buffer
             advantages = await compute_advantages.compute.call_one(episodes)
             for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)

-            # Log metrics
             rollout_count += 1
             record_metric(
                 "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
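To show the data flow this hunk edits in isolation: each rollout computes group advantages once, stamps them onto the episodes, and pushes every episode into the replay buffer. A runnable sketch with plain-Python stand-ins; the real `compute_advantages` and `replay_buffer` are actor handles invoked via `.call_one`.

```python
import asyncio

class Episode:
    """Hypothetical minimal episode; the real dataclass carries far more state."""
    def __init__(self, reward: float):
        self.reward = reward
        self.advantage: float | None = None

async def compute_advantages_stub(episodes: list) -> list[float]:
    # Simplified mean-centering; see the ComputeAdvantages sketch above.
    mean = sum(e.reward for e in episodes) / len(episodes)
    return [e.reward - mean for e in episodes]

async def rollout_once() -> None:
    replay_buffer: list[Episode] = []
    episodes = [Episode(0.0), Episode(1.0)]
    advantages = await compute_advantages_stub(episodes)
    for episode, advantage in zip(episodes, advantages):
        episode.advantage = advantage
        replay_buffer.append(episode)  # stands in for replay_buffer.add.call_one(...)
    print([e.advantage for e in replay_buffer])  # [-0.5, 0.5]

asyncio.run(rollout_once())
```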