[GRPO app] Example fails when max_res_tokens: 16384 with assert rope_cache.shape == (seqlen, head_dim * 2) #495

@lewtun

Description

🐛 Describe the bug

The GRPO example generates 512 tokens per prompt, which is too small for reasoning models to consistently terminate their rollouts with an EOS token. However, changing this to a more realistic value like 16k tokens, as follows:

# Global configuration
group_size: 8
local_batch_size: 16 # per-device batch size
max_req_tokens: 512
- max_res_tokens: 512
+ max_res_tokens: 16384
model: "Qwen/Qwen3-1.7B"
off_by_n: 1 # Off by one by default

produces an exception on broadcast:

Traceback (most recent call last):
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/monarch/_src/actor/actor_mesh.py", line 932, in handle
    result = await the_method(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/git/torchforge/src/forge/actors/reference_model.py", line 181, in forward
    logits = self.model(input_ids)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 491, in forward
    h = layer(h, self.rope_cache, attention_masks)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 345, in forward
    x = x + self.attention(self.attention_norm(x), rope_cache, attention_masks)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 219, in forward
    xq, xk = apply_rotary_emb(xq, xk, rope_cache)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 91, in apply_rotary_emb
    rope_cache = reshape_for_broadcast(rope_cache, xq)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/fsx/lewis/miniconda3/envs/forge/lib/python3.12/site-packages/torchtitan/models/qwen3/model/model.py", line 79, in reshape_for_broadcast
    assert rope_cache.shape == (seqlen, head_dim * 2)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
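
For context, I believe the assertion fires because the rope cache is precomputed for a fixed maximum sequence length, while the batch now spans max_req_tokens + max_res_tokens = 512 + 16384 = 16896 positions. A minimal sketch of the shape mismatch (cache_len and head_dim below are illustrative assumptions, not torchtitan's actual defaults):

import torch

head_dim = 128
cache_len = 2048                        # rope cache built for the model's max_seq_len
rope_cache = torch.randn(cache_len, head_dim * 2)

seqlen = 512 + 16384                    # max_req_tokens + max_res_tokens = 16896
sliced = rope_cache[:seqlen]            # slicing past the end silently truncates to cache_len

# the same check that fires in reshape_for_broadcast:
assert sliced.shape == (seqlen, head_dim * 2)  # fails: 2048 rows != 16896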

I ran this with the above change applied to qwen3_1_7b.yaml, using:

python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
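
As a pre-flight sanity check, it might be worth asserting that the combined token budget in the config fits the model's positional range before launching (a sketch; model_max_seq_len is a hypothetical placeholder for whatever max_seq_len the reference model is actually built with):

import yaml

with open("apps/grpo/qwen3_1_7b.yaml") as f:
    cfg = yaml.safe_load(f)

total = cfg["max_req_tokens"] + cfg["max_res_tokens"]  # 512 + 16384 = 16896
model_max_seq_len = 4096                               # hypothetical; substitute the real value
assert total <= model_max_seq_len, (
    f"sequence budget {total} exceeds rope cache length {model_max_seq_len}"
)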

Versions

No response
