🐛 Describe the bug
The GRPO example generates 512 tokens per prompt, which is too small for reasoning models to consistently terminate their rollouts with EOS. However, changing this to a more realistic value like 16k tokens as follows:

```diff
 # Global configuration
 group_size: 8
 local_batch_size: 16 # per-device batch size
 max_req_tokens: 512
-max_res_tokens: 512
+max_res_tokens: 16384
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default
```

produces an exception on broadcast:

```
Traceback (most recent call last):
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/monarch/_src/actor/actor_mesh.py", line 932, in handle
result = await the_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/git/torchforge/src/forge/actors/reference_model.py", line 181, in forward
logits = self.model(input_ids)
^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 491, in forward
h = layer(h, self.rope_cache, attention_masks)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 345, in forward
x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 219, in forward
xq, xk = apply_rotary_emb(xq, xk, rope_cache)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 91, in apply_rotary_emb
rope_cache = reshape_for_broadcast(rope_cache, xq)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 79, in reshape_for_broadcast
assert rope_cache.shape == (seqlen, head_dim * 2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
```
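For context on what the assertion is checking: `reshape_for_broadcast` requires the RoPE cache to match the input's sequence length exactly. A minimal, self-contained sketch of how that can fail (hypothetical shapes and dims; the function below only mirrors the check shown in the traceback and is not copied from torchtitan), assuming the reference model's cache was precomputed for the old `max_req_tokens + max_res_tokens = 1024` budget:

```python
# Sketch of the suspected shape mismatch. Assumption: the RoPE cache is sized
# for the old 512 + 512 = 1024-position budget, so a 16k-token rollout hands
# apply_rotary_emb a longer sequence than the cache covers.
import torch

def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Mirrors the failing check from the traceback: cache length must equal
    # the input sequence length, and its width must be head_dim * 2.
    seqlen, head_dim = x.shape[1], x.shape[-1]
    assert rope_cache.shape == (seqlen, head_dim * 2)
    return rope_cache.view(1, seqlen, 1, head_dim * 2)

head_dim, n_heads = 128, 16                           # hypothetical model dims
rope_cache = torch.randn(1024, head_dim * 2)          # cache built for 1024 positions
xq = torch.randn(1, 512 + 16384, n_heads, head_dim)   # 16k-token rollout
reshape_for_broadcast(rope_cache, xq)                 # AssertionError: 1024 != 16896
```

If that reading is right, the cache (or the model's max sequence length) would need to be sized from the configured `max_req_tokens + max_res_tokens` rather than the old defaults.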
I ran this with the above change to qwen3_1_7b.yaml and:

```bash
python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
```

Versions
No response