Skip to content

Commit b487c1b

Browse files
committed
hyperparameter
1 parent 71cf182 commit b487c1b

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def simple_grpo_loss(
129129
ref_logprobs: torch.Tensor,
130130
advantages: torch.Tensor,
131131
padding_mask: torch.Tensor,
132-
beta: float = 1e-5,
132+
beta: float = 1e-6,
133133
) -> torch.Tensor:
134134
logprobs: torch.Tensor = compute_logprobs(logits, response)
135135
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1

apps/grpo/qwen3_8b.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_8b.yaml
33

44
# Global configuration
5-
group_size: 8
6-
local_batch_size: 8 # per-device batch size
5+
group_size: 16
6+
local_batch_size: 4 # per-device batch size
77
max_req_tokens: 1024
88
max_res_tokens: 2048
99
model: "Qwen/Qwen3-8B"

0 commit comments

Comments
 (0)