Temporarily use no packing in SFT #614
base: main
Conversation
if self.job_config.training.compile:
    raise ValueError(
        "training.compile=True is not currently supported. "
        "Compile is only supported with flex attention enabled, which requires PyTorch nightly. "
Any objection to starting a main issue to track the nightly build?
Sounds good! But can we first nail down the different subtasks via the Google Doc I just shared? Then we can translate that into a GitHub issue.
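For reference, a minimal sketch (not from this PR) of how flex attention availability could be detected before allowing training.compile=True; the constant name HAS_FLEX_ATTENTION is hypothetical:

```python
# Sketch only: flex attention ships in torch.nn.attention.flex_attention
# (PyTorch 2.5+ / nightly); its absence is what the error above guards against.
try:
    from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
    HAS_FLEX_ATTENTION = True
except ImportError:
    HAS_FLEX_ATTENTION = False
```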
# Flatten if all are lists
if all(isinstance(item, list) for item in result[key]):
    result[key] = [item for sublist in result[key] for item in sublist]
Is this a common practice? It feels like an unnecessary operation / tribal knowledge.
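For illustration only, a standalone example of what the flattening above does; itertools.chain.from_iterable is an equivalent, arguably more self-documenting alternative (names below are illustrative, not from the PR):

```python
import itertools

# Collated non-tensor field: each sample contributed a list of metric entries.
result = {"metrics": [["loss_a"], ["loss_b", "loss_c"]]}
key = "metrics"

if all(isinstance(item, list) for item in result[key]):
    # Same result as the nested comprehension in the diff above.
    result[key] = list(itertools.chain.from_iterable(result[key]))

print(result["metrics"])  # ['loss_a', 'loss_b', 'loss_c']
```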
Pads 'tokens' with 0 and 'labels' with CROSS_ENTROPY_IGNORE_IDX (-100).
Non-tensor fields (like metrics) are collected into lists and flattened
if all items are lists.
Is it common practice to assume tokens and labels are the keys for collate_padded?
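For context, a minimal sketch of a padded collate along the lines the docstring describes, assuming each sample is a dict with 'tokens' and 'labels' tensors; the function name and structure are illustrative, not the repository's actual collate_padded:

```python
from typing import Any

import torch
from torch.nn.utils.rnn import pad_sequence

CROSS_ENTROPY_IGNORE_IDX = -100  # value referenced in the docstring above


def collate_padded_sketch(batch: list[dict[str, Any]]) -> dict[str, Any]:
    """Pad 'tokens' with 0 and 'labels' with the ignore index; pass other fields through as lists."""
    tokens = pad_sequence(
        [sample["tokens"] for sample in batch], batch_first=True, padding_value=0
    )
    labels = pad_sequence(
        [sample["labels"] for sample in batch],
        batch_first=True,
        padding_value=CROSS_ENTROPY_IGNORE_IDX,
    )
    out: dict[str, Any] = {"tokens": tokens, "labels": labels}
    # Non-tensor fields (e.g. metrics) are collected into per-key lists.
    for key in batch[0]:
        if key not in ("tokens", "labels"):
            out[key] = [sample[key] for sample in batch]
    return out
```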
Context
Our current SFT recipe uses packing. However, to use packing you need to pass a block-causal mask in the forward pass. We construct this mask, but it is ignored in titan because an additional mask is only allowed when the model definition specifies it and denotes that the model is using flex attention. Since we are unable to control the exact model definitions ourselves, this is a temporary fix to ensure that our training is correct.
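For reference, a minimal sketch of the kind of block-causal mask that packing requires: a position may only attend to earlier positions within the same packed document. This is illustrative and not the repository's mask construction:

```python
import torch


def block_causal_mask(document_ids: torch.Tensor) -> torch.Tensor:
    """Build a [seq_len, seq_len] boolean attention mask for one packed sequence.

    document_ids[i] identifies the packed document that position i belongs to.
    A query position q may attend to key position k only if k <= q and both
    positions fall in the same document.
    """
    seq_len = document_ids.shape[0]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    same_doc = document_ids.unsqueeze(0) == document_ids.unsqueeze(1)
    return causal & same_doc


# Two packed documents of lengths 3 and 2.
mask = block_causal_mask(torch.tensor([0, 0, 0, 1, 1]))
```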
The TRUE fix(es) would be:
Changes
SFT main.py
Collate.py
Configs:
Testing
Wandb logs: https://wandb.ai/jcummings/sft-training
See the output below with compile=True.
Open questions