Fix loss masking #445
base: main
Conversation
```python
labels = batch.tokens.crop(labels_begin, labels_end).tokens
loss_mask = labels >= 0
```
Is this really what we want? We can't train the model to produce these labels, but it might make sense to compute other losses?
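A minimal sketch of the masking convention under discussion, using hypothetical tensors and assuming the common convention that non-predictable positions carry negative label ids (e.g. `-100`):

```python
import numpy as np

# Hypothetical labels: negative ids mark positions that should not
# contribute to the token-prediction loss (padding, placeholders, ...).
labels = np.array([12, 7, -100, 42, -100, 3])

# Positions with non-negative labels are trainable for next-token prediction.
loss_mask = labels >= 0

# Illustrative per-token losses; masked positions are zeroed out and the
# mean is taken over unmasked positions only.
per_token_loss = np.array([0.5, 1.2, 9.9, 0.3, 9.9, 0.8])
masked_loss = np.where(loss_mask, per_token_loss, 0.0).sum() / loss_mask.sum()
```

The open question above is whether losses other than token prediction should reuse this same mask.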
Can we skip this when not needed?
```python
if (
    self._config.head.distillation_model is not None
    or self._config.decoder.block.distillation_model is not None
```
Activation distillation ignores loss_mask; it uses activation_mask instead.
Does that even make sense? These masks refer to token prediction, which isn't really a thing at the activation stage. We could take the next token instead, but that raises several concerns (especially with MTP). Actually, I think we shouldn't mask those positions. They may not be used for next-token prediction, but the keys and values produced from these activations are used further down in the sequence, which means we do train these activations.
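The point about keys and values can be illustrated with a toy causal-attention example (hypothetical shapes, not the repo's implementation): perturbing the value at a loss-masked position still changes the outputs at later positions, so those activations do influence training.

```python
import numpy as np

def causal_attention(q, k, v):
    """Toy single-head causal attention over a (T, d) sequence."""
    T = q.shape[0]
    scores = q @ k.T
    causal = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(causal, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(4, 8)) for _ in range(3))

out_before = causal_attention(q, k, v)
v2 = v.copy()
v2[1] += 1.0  # perturb the value at a position excluded from the token loss
out_after = causal_attention(q, k, v2)

# Position 0 cannot attend to position 1, so it is unchanged...
assert np.allclose(out_before[0], out_after[0])
# ...but all later positions see the perturbed key/value stream.
assert not np.allclose(out_before[2:], out_after[2:])
```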
✨ Description
Addresses #442
`loss_masks` should include padding and image placeholder tokens.

TODO:
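Roughly, the intended masking could look like the sketch below. The token ids are hypothetical; the real values depend on the tokenizer and model config.

```python
import numpy as np

# Hypothetical special-token ids for illustration only.
PAD_TOKEN = 0
IMAGE_PLACEHOLDER_TOKEN = 32000

tokens = np.array([5, 9, IMAGE_PLACEHOLDER_TOKEN, 7, PAD_TOKEN, PAD_TOKEN])

# Exclude both padding and image placeholder positions from the loss.
loss_mask = (tokens != PAD_TOKEN) & (tokens != IMAGE_PLACEHOLDER_TOKEN)
```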
🔍 Type of change
Select all that apply:
📝 Changes
List the key changes introduced in this PR:
✅ Checklist
Make sure the following tasks are completed before submitting the PR:
General
Dependencies and Configuration
Testing
Performance Impact
📊 Performance Impact Details
If there is any impact on performance, describe it and provide benchmark results, if applicable:
🗒️ Additional Notes
Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.