
Conversation

ebsmothers
Contributor

After the changes in pytorch/torchtitan#1616, we need to explicitly initialize the attention mask in our trainer code.
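
For reference, the fix amounts to building torchtitan's FlexAttention block mask once per batch in the trainer, before the model's forward pass, so that FlexAttention.block_masks has an entry for the configured mask key. A minimal sketch of the pattern, assuming torchtitan exposes init_attention_mask in torchtitan.models.attention with an (inputs, eos_id) signature (verify against your torchtitan revision):

import torch
from torchtitan.models.attention import init_attention_mask

def forward_step(model: torch.nn.Module, input_ids: torch.Tensor, eos_id: int):
    # Populate FlexAttention.block_masks for this batch's mask key before the
    # forward pass; skipping this step is what produces the KeyError below.
    # The helper name and argument order are assumptions; check your checkout
    # of torchtitan/models/attention.py.
    init_attention_mask(input_ids, eos_id)
    return model(input_ids)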

Test plan: hacked Llama3 8B in my local torchtitan code to enable flex attention as in this config, then ran:

forge run --nproc_per_node 2 apps/sft/main.py --config apps/sft/llama3_8b.yaml

On main:

...
[rank0]:     output = self.sdpa(xq, xk, xv)
[rank0]:   File "/home/ebs/.fbpkg_conda_envs/forge-6f4168f/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ebs/.fbpkg_conda_envs/forge-6f4168f/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ebs/torchtitan/torchtitan/models/attention.py", line 88, in forward
[rank0]:     block_mask = FlexAttention.block_masks[self.mask_key]
[rank0]: KeyError: ('block_causal', None)

On this branch:

...
4|Loss: 12.06763744354248:   0%|▉                    | 5/1000 [00:08<23:58,  1.45s/it]

ebsmothers merged commit 89903fa into meta-pytorch:main on Aug 26, 2025
4 checks passed