We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 71cf182, commit b487c1b (Copy full SHA for b487c1b)
apps/grpo/main.py
@@ -129,7 +129,7 @@ def simple_grpo_loss(
129
ref_logprobs: torch.Tensor,
130
advantages: torch.Tensor,
131
padding_mask: torch.Tensor,
132
- beta: float = 1e-5,
+ beta: float = 1e-6,
133
) -> torch.Tensor:
134
logprobs: torch.Tensor = compute_logprobs(logits, response)
135
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
apps/grpo/qwen3_8b.yaml
@@ -2,8 +2,8 @@
2
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_8b.yaml
3
4
# Global configuration
5
-group_size: 8
6
-local_batch_size: 8 # per-device batch size
+group_size: 16
+local_batch_size: 4 # per-device batch size
7
max_req_tokens: 1024
8
max_res_tokens: 2048
9
model: "Qwen/Qwen3-8B"
0 commit comments