Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def collate(
return inputs, targets


# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
def simple_grpo_loss(
logits: torch.Tensor,
response: torch.Tensor,
Expand All @@ -128,12 +129,7 @@ def simple_grpo_loss(
padding_mask: torch.Tensor,
beta: float = 0.1,
) -> torch.Tensor:
"""
Example GRPO Loss Function for RLTrainer
"""
logprobs: torch.Tensor = compute_logprobs(logits, response)

# Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - beta * kl)
Expand All @@ -146,7 +142,6 @@ def simple_grpo_loss(

@dataclass
class RewardActor(ForgeActor):
"""Reward actor that uses a list of scoring functions."""

reward_functions: list[Callable]

Expand Down Expand Up @@ -178,14 +173,12 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
Reduce.STD,
)

# avg total reward
record_metric(
"reward/evaluate_response/avg_total_reward",
reward,
Reduce.MEAN,
)

# count fn calls
record_metric(
f"reward/evaluate_response/count_{reward_fn_name}_calls",
1,
Expand All @@ -198,8 +191,6 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl

@dataclass
class ComputeAdvantages(ForgeActor):
"""Compute advantages for GRPO using reward signals."""

@endpoint
async def compute(self, group: Group) -> list[float]:
# TODO: add batch processing
Expand Down Expand Up @@ -255,7 +246,6 @@ async def sample(self) -> dict[str, str] | None:
try:
sample = next(self._iterator)

# Record dataset metrics
record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
record_metric(
"dataset/sample/avg_sample_len",
Expand Down Expand Up @@ -406,13 +396,11 @@ async def continuous_rollouts():
episode.ref_logprobs = ref_logprobs[i]
del ref_logprobs, input_ids

# Calculate advantages and add to replay buffer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:O i kind like this. It works as a visual way to segment code blocks

advantages = await compute_advantages.compute.call_one(episodes)
for episode, advantage in zip(episodes, advantages):
episode.advantage = advantage
await replay_buffer.add.call_one(episode)

# Log metrics
rollout_count += 1
record_metric(
"main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
Expand Down
Loading