Skip to content

Commit 1f9e349

Browse files
committed
Revert unintended changes from previous commit
Only src/forge/__init__.py was intended to be changed.
1 parent 3627ec1 commit 1f9e349

File tree

3 files changed

+6
-30
lines changed

3 files changed

+6
-30
lines changed

.claude/settings.local.json

Lines changed: 0 additions & 14 deletions
This file was deleted.

apps/grpo/main.py

Lines changed: 4 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -6,19 +6,6 @@
66

77
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
88

9-
# Patch importlib.metadata.distributions before wandb imports it
10-
# to filter out packages with None metadata
11-
import importlib.metadata
12-
_original_distributions = importlib.metadata.distributions
13-
14-
def _patched_distributions():
15-
"""Filter out distributions with None metadata"""
16-
for dist in _original_distributions():
17-
if dist.metadata is not None:
18-
yield dist
19-
20-
importlib.metadata.distributions = _patched_distributions
21-
229
import asyncio
2310
import time
2411
import uuid
@@ -138,9 +125,12 @@ def simple_grpo_loss(
138125
ref_logprobs: torch.Tensor,
139126
advantages: torch.Tensor,
140127
padding_mask: torch.Tensor,
128+
beta: float = 0.1,
141129
) -> torch.Tensor:
142130
logprobs: torch.Tensor = compute_logprobs(logits, response)
143-
per_token_loss = torch.exp(logprobs - logprobs.detach()) * advantages.detach()
131+
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
132+
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
133+
per_token_loss = -(per_token_policy_loss - beta * kl)
144134
loss = (
145135
((per_token_loss * padding_mask).sum(dim=1))
146136
/ (padding_mask.sum(dim=1).clamp(min=1.0))

apps/grpo/qwen3_1_7b.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,10 @@ local_batch_size: 16 # per-device batch size
77
max_req_tokens: 1024
88
max_res_tokens: 1024
99
model: "Qwen/Qwen3-1.7B"
10-
off_by_n: 0 # Off by one by default
10+
off_by_n: 1 # Off by one by default
1111

1212
# Main loop configuration
13-
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
13+
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
1414

1515

1616
# Observability configuration

0 commit comments

Comments
 (0)