File tree Expand file tree Collapse file tree 4 files changed +42
-17
lines changed Expand file tree Collapse file tree 4 files changed +42
-17
lines changed Original file line number Diff line number Diff line change 1+ {
2+ "permissions" : {
3+ "allow" : [
4+ " Bash(source ~/.bashrc)" ,
5+ " Bash(conda activate forge-monarch-0-1-1)" ,
6+ " Bash(pip install:*)" ,
7+ " Bash(https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 pip install:*)" ,
8+ " Bash(python -m pytest:*)" ,
9+ " Bash(https_proxy=http://fwdproxy:8080 http_proxy=http://fwdproxy:8080 git push:*)"
10+ ],
11+ "deny" : [],
12+ "ask" : []
13+ }
14+ }
Original file line number Diff line number Diff line change 66
77# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
88
9+ # Patch importlib.metadata.distributions before wandb imports it
10+ # to filter out packages with None metadata
11+ import importlib .metadata
12+ _original_distributions = importlib .metadata .distributions
13+
14+ def _patched_distributions ():
15+ """Filter out distributions with None metadata"""
16+ for dist in _original_distributions ():
17+ if dist .metadata is not None :
18+ yield dist
19+
20+ importlib .metadata .distributions = _patched_distributions
21+
922import asyncio
1023import time
1124import uuid
@@ -125,12 +138,9 @@ def simple_grpo_loss(
125138 ref_logprobs : torch .Tensor ,
126139 advantages : torch .Tensor ,
127140 padding_mask : torch .Tensor ,
128- beta : float = 0.1 ,
129141) -> torch .Tensor :
130142 logprobs : torch .Tensor = compute_logprobs (logits , response )
131- kl = torch .exp (ref_logprobs - logprobs ) - (ref_logprobs - logprobs ) - 1
132- per_token_policy_loss = torch .exp (logprobs - logprobs .detach ()) * advantages
133- per_token_loss = - (per_token_policy_loss - beta * kl )
143+ per_token_loss = torch .exp (logprobs - logprobs .detach ()) * advantages .detach ()
134144 loss = (
135145 ((per_token_loss * padding_mask ).sum (dim = 1 ))
136146 / (padding_mask .sum (dim = 1 ).clamp (min = 1.0 ))
Original file line number Diff line number Diff line change @@ -7,10 +7,10 @@ local_batch_size: 16 # per-device batch size
77max_req_tokens : 1024
88max_res_tokens : 1024
99model : " Qwen/Qwen3-1.7B"
10- off_by_n : 1 # Off by one by default
10+ off_by_n : 0 # Off by one by default
1111
1212# Main loop configuration
13- rollout_threads : 1 # Recommended to set equal to policy.num_replicas
13+ rollout_threads : 1 # Recommended to set equal to policy.num_replicas
1414
1515
1616# Observability configuration
Original file line number Diff line number Diff line change 2525# to filter out packages with None metadata
2626import importlib .metadata
2727
28- _original_distributions = importlib .metadata .distributions
29-
30-
31- def _patched_distributions ():
32- """Filter out distributions with None metadata"""
33- for dist in _original_distributions ():
34- if dist .metadata is not None :
35- yield dist
36-
37-
38- importlib .metadata .distributions = _patched_distributions
28+ # Guard to ensure this runs only once
29+ if not hasattr (importlib .metadata , "_distributions_patched" ):
30+ _original_distributions = importlib .metadata .distributions
31+
32+ def _patched_distributions ():
33+ """Filter out distributions with None metadata"""
34+ for dist in _original_distributions ():
35+ if dist .metadata is not None :
36+ yield dist
37+
38+ importlib .metadata .distributions = _patched_distributions
39+ importlib .metadata ._distributions_patched = True
You can’t perform that action at this time.
0 commit comments