File tree Expand file tree Collapse file tree 3 files changed +6
-30
lines changed Expand file tree Collapse file tree 3 files changed +6
-30
lines changed Load Diff This file was deleted.
Original file line number Diff line number Diff line change 66
77# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
88
9- # Patch importlib.metadata.distributions before wandb imports it
10- # to filter out packages with None metadata
11- import importlib .metadata
12- _original_distributions = importlib .metadata .distributions
13-
14- def _patched_distributions ():
15- """Filter out distributions with None metadata"""
16- for dist in _original_distributions ():
17- if dist .metadata is not None :
18- yield dist
19-
20- importlib .metadata .distributions = _patched_distributions
21-
229import asyncio
2310import time
2411import uuid
@@ -138,9 +125,12 @@ def simple_grpo_loss(
138125 ref_logprobs : torch .Tensor ,
139126 advantages : torch .Tensor ,
140127 padding_mask : torch .Tensor ,
128+ beta : float = 0.1 ,
141129) -> torch .Tensor :
142130 logprobs : torch .Tensor = compute_logprobs (logits , response )
143- per_token_loss = torch .exp (logprobs - logprobs .detach ()) * advantages .detach ()
131+ kl = torch .exp (ref_logprobs - logprobs ) - (ref_logprobs - logprobs ) - 1
132+ per_token_policy_loss = torch .exp (logprobs - logprobs .detach ()) * advantages
133+ per_token_loss = - (per_token_policy_loss - beta * kl )
144134 loss = (
145135 ((per_token_loss * padding_mask ).sum (dim = 1 ))
146136 / (padding_mask .sum (dim = 1 ).clamp (min = 1.0 ))
Original file line number Diff line number Diff line change @@ -7,10 +7,10 @@ local_batch_size: 16 # per-device batch size
77max_req_tokens : 1024
88max_res_tokens : 1024
99model : " Qwen/Qwen3-1.7B"
10- off_by_n : 0 # Off by one by default
10+ off_by_n : 1 # Off by one by default
1111
1212# Main loop configuration
13- rollout_threads : 1 # Recommended to set equal to policy.num_replicas
13+ rollout_threads : 1 # Recommended to set equal to policy.num_replicas
1414
1515
1616# Observability configuration
You can’t perform that action at this time.
0 commit comments