
Conversation

@joecummings joecummings commented Dec 2, 2025

Context

Our current SFT recipe uses packing. However, in order to use packing you need to pass a block_causal mask in the forward pass. We construct this mask, but it is ignored in titan because titan only allows an additional mask to be passed in if the model definition specifies it and denotes that the model is using flex attention. Since we are unable to control the exact model definitions ourselves, this is a temporary fix to ensure that our training is correct.
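
For reference, a minimal illustration (not code from this PR; the helper name is made up) of why packing needs a block-causal mask: each packed sequence gets its own causal block, so tokens cannot attend across sequence boundaries the way a plain causal mask would silently allow when the mask is ignored.

import torch

def block_causal_mask(seq_lens):
    # Boolean attention mask over the packed sample: True = "may attend".
    total = sum(seq_lens)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for n in seq_lens:
        # Lower-triangular (causal) block confined to this sequence's span.
        mask[start:start + n, start:start + n] = torch.tril(torch.ones(n, n, dtype=torch.bool))
        start += n
    return mask

# Two sequences of lengths 2 and 3 packed into one sample of length 5.
print(block_causal_mask([2, 3]).int())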

The TRUE fix(es) would be:

  1. Move everything to nightlies and push qwen3_flex and llama3_flex versions to titan, then default to using those
  2. Work with the Titan team to allow us to override model definitions for cases like this, where we might want to try out both flex and normal attention

Changes

SFT main.py

  • Remove references to packing code
  • Import and use new padding function
  • Add validation that raises an error if training.compile=True, since compile currently requires flex attention

Collate.py

  • Add a new function that pads to the longest sequence in the batch (see the sketch below)
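
A rough sketch of what the padding collate might look like, assuming samples are dicts of Python lists keyed by "tokens" and "labels" (the actual signature and non-tensor field handling in collate.py may differ):

import torch

CROSS_ENTROPY_IGNORE_IDX = -100

def collate_padded(batch):
    # Pad every sample to the longest sequence in the batch:
    # tokens with 0, labels with CROSS_ENTROPY_IGNORE_IDX (-100).
    max_len = max(len(sample["tokens"]) for sample in batch)
    tokens, labels = [], []
    for sample in batch:
        pad = max_len - len(sample["tokens"])
        tokens.append(torch.tensor(sample["tokens"] + [0] * pad))
        labels.append(torch.tensor(sample["labels"] + [CROSS_ENTROPY_IGNORE_IDX] * pad))
    return {"tokens": torch.stack(tokens), "labels": torch.stack(labels)}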

Configs:

  • Increased batch size from 1 to 8, since we can no longer fit multiple sequences in a single sample

Testing

  1. Works with default configs as confirmed by training loss and eval loss decreasing

Wandb logs: https://wandb.ai/jcummings/sft-training

  2. Validation works correctly

See the output below with compile=True:

[ForgeSFTRecipe-2/8] 2025-12-03 09:47:27 CRITICAL Unhandled exception in actor endpoint
Traceback (most recent call last):
  File "/home/jrcummings/.conda/envs/forge-uv/lib/python3.12/site-packages/monarch/_src/actor/actor_mesh.py", line 935, in handle
    result = await the_method(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jrcummings/projects/joe-forge/apps/sft/main.py", line 100, in setup
    raise ValueError(
ValueError: training.compile=True is not currently supported. Compile is only supported with flex attention enabled, which requires PyTorch nightly. Please set training.compile=false in your config.

Open questions

  • Why is the number of samples processed different between Llama3 8B and Qwen3 8B? Their sequence lengths should be the same.

@meta-cla bot added the CLA Signed label Dec 2, 2025
@joecummings marked this pull request as ready for review December 3, 2025 18:17
if self.job_config.training.compile:
    raise ValueError(
        "training.compile=True is not currently supported. "
        "Compile is only supported with flex attention enabled, which requires PyTorch nightly. "
        "Please set training.compile=false in your config."
    )
Contributor

Any objection to starting a main issue tracking the nightly build?

Member Author

Sounds good! But can we first nail down the different subtasks via the Google Doc I just shared? Then we can translate to a GI.

Comment on lines +64 to +66
# Flatten if all are lists
if all(isinstance(item, list) for item in result[key]):
    result[key] = [item for sublist in result[key] for item in sublist]
Contributor

Is this a common practice? Feels like unnecessary operation / tribal knowledge

Comment on lines +19 to +21
Pads 'tokens' with 0 and 'labels' with CROSS_ENTROPY_IGNORE_IDX (-100).
Non-tensor fields (like metrics) are collected into lists and flattened
if all items are lists.
Contributor

Is it common practice to assume tokens and labels are the keys for collate_padded?
