Skip to content

Commit 3627ec1

Browse files
committed
fix
1 parent fef0230 commit 3627ec1

File tree

4 files changed

+42
-17
lines changed

4 files changed

+42
-17
lines changed

.claude/settings.local.json

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
{
2+
"permissions": {
3+
"allow": [
4+
"Bash(source ~/.bashrc)",
5+
"Bash(conda activate forge-monarch-0-1-1)",
6+
"Bash(pip install:*)",
7+
"Bash(https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 pip install:*)",
8+
"Bash(python -m pytest:*)",
9+
"Bash(https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 git push:*)"
10+
],
11+
"deny": [],
12+
"ask": []
13+
}
14+
}

apps/grpo/main.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,19 @@
66

77
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
88

9+
# Patch importlib.metadata.distributions before wandb imports it
10+
# to filter out packages with None metadata
11+
import importlib.metadata
12+
_original_distributions = importlib.metadata.distributions
13+
14+
def _patched_distributions():
15+
"""Filter out distributions with None metadata"""
16+
for dist in _original_distributions():
17+
if dist.metadata is not None:
18+
yield dist
19+
20+
importlib.metadata.distributions = _patched_distributions
21+
922
import asyncio
1023
import time
1124
import uuid
@@ -125,12 +138,9 @@ def simple_grpo_loss(
125138
ref_logprobs: torch.Tensor,
126139
advantages: torch.Tensor,
127140
padding_mask: torch.Tensor,
128-
beta: float = 0.1,
129141
) -> torch.Tensor:
130142
logprobs: torch.Tensor = compute_logprobs(logits, response)
131-
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
132-
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
133-
per_token_loss = -(per_token_policy_loss - beta * kl)
143+
per_token_loss = torch.exp(logprobs - logprobs.detach()) * advantages.detach()
134144
loss = (
135145
((per_token_loss * padding_mask).sum(dim=1))
136146
/ (padding_mask.sum(dim=1).clamp(min=1.0))

apps/grpo/qwen3_1_7b.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ local_batch_size: 16 # per-device batch size
77
max_req_tokens: 1024
88
max_res_tokens: 1024
99
model: "Qwen/Qwen3-1.7B"
10-
off_by_n: 1 # Off by one by default
10+
off_by_n: 0 # Off by one by default
1111

1212
# Main loop configuration
13-
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
13+
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
1414

1515

1616
# Observability configuration

src/forge/__init__.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@
2525
# to filter out packages with None metadata
2626
import importlib.metadata
2727

28-
_original_distributions = importlib.metadata.distributions
29-
30-
31-
def _patched_distributions():
32-
"""Filter out distributions with None metadata"""
33-
for dist in _original_distributions():
34-
if dist.metadata is not None:
35-
yield dist
36-
37-
38-
importlib.metadata.distributions = _patched_distributions
28+
# Guard to ensure this runs only once
29+
if not hasattr(importlib.metadata, "_distributions_patched"):
30+
_original_distributions = importlib.metadata.distributions
31+
32+
def _patched_distributions():
33+
"""Filter out distributions with None metadata"""
34+
for dist in _original_distributions():
35+
if dist.metadata is not None:
36+
yield dist
37+
38+
importlib.metadata.distributions = _patched_distributions
39+
importlib.metadata._distributions_patched = True

0 commit comments

Comments
 (0)