
Commit e16e057

Remove useless comments (#483)
1 parent 8d7cb10 commit e16e057

File tree

1 file changed: +1 / -13 lines


apps/grpo/main.py

Lines changed: 1 addition & 13 deletions
@@ -120,6 +120,7 @@ def collate(
     return inputs, targets


+# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
 def simple_grpo_loss(
     logits: torch.Tensor,
     response: torch.Tensor,
@@ -128,12 +129,7 @@ def simple_grpo_loss(
     padding_mask: torch.Tensor,
     beta: float = 0.1,
 ) -> torch.Tensor:
-    """
-    Example GRPO Loss Function for RLTrainer
-    """
     logprobs: torch.Tensor = compute_logprobs(logits, response)
-
-    # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
     kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
     per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
     per_token_loss = -(per_token_policy_loss - beta * kl)
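Editor's note: the body kept by this hunk is the per-token GRPO objective, i.e. the k3 KL estimator against the reference policy plus a surrogate policy term whose exp(logprobs - logprobs.detach()) factor equals 1 in value but still carries the policy gradient. A minimal self-contained sketch of the same computation follows; the tensor shapes and the masked-mean reduction at the end are assumptions for illustration, since the hunk cuts off before the return statement.

import torch

def grpo_loss_sketch(
    logprobs: torch.Tensor,      # [batch, seq] log-probs of sampled tokens under the current policy
    ref_logprobs: torch.Tensor,  # [batch, seq] log-probs of the same tokens under the reference policy
    advantages: torch.Tensor,    # [batch, 1] group-normalized advantages
    padding_mask: torch.Tensor,  # [batch, seq] float mask, 1 for response tokens, 0 for padding
    beta: float = 0.1,
) -> torch.Tensor:
    # k3 estimator of KL(policy || reference), computed per token (as in the diff).
    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
    # Surrogate policy term: equals 1 in value, but gradients flow through logprobs (as in the diff).
    per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
    per_token_loss = -(per_token_policy_loss - beta * kl)
    # Assumed reduction: mean over non-padded tokens per sequence, then mean over the batch.
    per_seq_loss = (per_token_loss * padding_mask).sum(dim=1) / padding_mask.sum(dim=1).clamp(min=1.0)
    return per_seq_loss.mean()

Per the comment added in the previous hunk, the same loss is also available as `SimpleGRPOLoss` in `losses.grpo_loss`; the sketch above is only an illustration of the math, not that implementation.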
@@ -146,7 +142,6 @@ def simple_grpo_loss(

 @dataclass
 class RewardActor(ForgeActor):
-    """Reward actor that uses a list of scoring functions."""

     reward_functions: list[Callable]

@@ -178,14 +173,12 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
                Reduce.STD,
            )

-            # avg total reward
            record_metric(
                "reward/evaluate_response/avg_total_reward",
                reward,
                Reduce.MEAN,
            )

-            # count fn calls
            record_metric(
                f"reward/evaluate_response/count_{reward_fn_name}_calls",
                1,
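The inline comments removed in this hunk only restated the metric keys below them. For context, RewardActor scores each (prompt, response, target) triple with the callables in its `reward_functions` list; the exact callable signature is an assumption inferred from the `evaluate_response` header above. A minimal hypothetical scoring function of that shape:

def exact_match_reward(prompt: str, response: str, target: str) -> float:
    # Toy rule: full credit when the response matches the target exactly, else none.
    # A real scoring function would typically extract and compare answers more robustly.
    return 1.0 if response.strip() == target.strip() else 0.0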
@@ -198,8 +191,6 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl

 @dataclass
 class ComputeAdvantages(ForgeActor):
-    """Compute advantages for GRPO using reward signals."""
-
     @endpoint
     async def compute(self, group: Group) -> list[float]:
         # TODO: add batch processing
@@ -255,7 +246,6 @@ async def sample(self) -> dict[str, str] | None:
         try:
             sample = next(self._iterator)

-            # Record dataset metrics
             record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
             record_metric(
                 "dataset/sample/avg_sample_len",
@@ -406,13 +396,11 @@ async def continuous_rollouts():
                 episode.ref_logprobs = ref_logprobs[i]
             del ref_logprobs, input_ids

-            # Calculate advantages and add to replay buffer
             advantages = await compute_advantages.compute.call_one(episodes)
             for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)

-            # Log metrics
             rollout_count += 1
             record_metric(
                 "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
