Skip to content

Commit ed9c759

Browse files
committed
Update AsyncGRPOConfig chunk_lm_head default to 8192
1 parent 357c6ad commit ed9c759

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

trl/experimental/async_grpo/async_grpo_config.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,11 @@ class AsyncGRPOConfig(_BaseConfig):
6161
6262
> Parameters that control the LM head
6363
64-
chunk_lm_head (`int` or `None`, *optional*, defaults to `None`):
64+
chunk_lm_head (`int` or `None`, *optional*, defaults to `8192`):
6565
Chunk size for the fused LM head. When set, the lm_head computes log-probs and entropy without
6666
materializing the full `[batch, seq, vocab]` logits tensor, processing the vocabulary in chunks of this
6767
size instead. Reduces peak memory at the cost of extra matmuls. If `None`, uses the standard full-logits
6868
path.
69-
7069
> Parameters that control the async rollout pipeline
7170
7271
max_inflight_tasks (`int`, *optional*, defaults to `-1`):
@@ -168,7 +167,7 @@ class AsyncGRPOConfig(_BaseConfig):
168167

169168
# Parameters that control the LM head
170169
chunk_lm_head_size: int | None = field(
171-
default=None,
170+
default=8192,
172171
metadata={
173172
"help": "Chunk size for the fused LM head. When set, the lm_head computes log-probs and entropy "
174173
"without materializing the full [batch, seq, vocab] logits tensor, processing the vocabulary in "
@@ -177,7 +176,6 @@ class AsyncGRPOConfig(_BaseConfig):
177176
"forward pass)."
178177
},
179178
)
180-
181179
# Parameters that control the async rollout pipeline
182180
max_inflight_tasks: int = field(
183181
default=-1,

0 commit comments

Comments
 (0)