feat: Add MFT (Minifinetuning) loss support for knowledge distillation #14298
base: main
Conversation
- Add MFTLoss class implementing Minifinetuning distillation loss from https://arxiv.org/abs/2506.15702
- Implement corrected distribution preparation with threshold-based teacher probability adjustment
- Add support for both incorrect argmax and separation threshold corrections
- Update DistillationConfig to support MFT mode with configurable threshold parameter
- Integrate MFT loss option in distillation pipeline with automatic selection based on config
- Add validation for MFT threshold parameter (must be between 0 and 1)
- Note: MFT loss currently does not support tensor model parallelism

The MFT loss provides an alternative to standard KL divergence by correcting teacher distributions based on ground truth labels and a configurable threshold, potentially improving distillation quality for language models.

Signed-off-by: pbelcak <[email protected]>
Signed-off-by: pbelcak <[email protected]>
…commit. Signed-off-by: pbelcak <[email protected]>
(forgot to update CHANGELOG.md, made no other changes)

@sharathts all good?

@pbelcak LGTM, approved!

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.
@pbelcak could you cherry-pick the commit that fixes pylint? I will then try again and see if we can resolve the lint failure.
I see that the changes needed here are largely cosmetic; I plan to have a look at it this week.
Signed-off-by: Zhiyu Li <[email protected]>
Hi @ZhiyuLi-Nvidia, thanks for the fix, it's done. Can we merge?
Signed-off-by: Zhiyu Li <[email protected]>
Hi @pbelcak, all tests passed except for the low test coverage check. Could you add a unit test and rebase the branch?
What does this PR do?
Adds MFT (Minifinetuning; NVR paper) loss support for knowledge distillation with configurable threshold-based teacher probability correction.
The MFT loss provides an alternative to standard KL divergence by correcting teacher distributions based on ground truth labels and a configurable threshold, potentially improving distillation quality for language models.
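To make the correction idea concrete, below is a minimal illustrative sketch in PyTorch. It is not the formulation from the paper or from this PR's MFTLoss / _prepare_corrected_distributions; the function name, tensor shapes, and the specific rule (raise the ground-truth token's probability to at least the threshold, then renormalize the rest) are assumptions chosen only to illustrate threshold-based correction.

```python
# Illustrative sketch only: NOT the exact MFT correction from the paper or this PR.
# The rule here (boost the ground-truth token's probability to at least `threshold`,
# then rescale the remaining mass) is an assumption for illustration.
import torch


def correct_teacher_distribution(
    teacher_probs: torch.Tensor,  # (batch, seq, vocab) teacher probabilities
    labels: torch.Tensor,         # (batch, seq) ground-truth token ids
    threshold: float,             # corresponds to mft_threshold; must be in (0, 1)
) -> torch.Tensor:
    assert 0.0 < threshold < 1.0, "threshold must be between 0 and 1"

    # Probability the teacher assigns to the ground-truth token.
    p_true = teacher_probs.gather(-1, labels.unsqueeze(-1))      # (batch, seq, 1)
    p_true_new = torch.clamp(p_true, min=threshold)              # boost if below threshold

    # Rescale all other tokens so each distribution still sums to 1.
    scale = (1.0 - p_true_new) / (1.0 - p_true).clamp_min(1e-12)
    corrected = teacher_probs * scale
    corrected = corrected.scatter(-1, labels.unsqueeze(-1), p_true_new)
    return corrected
```

In the actual implementation, such a corrected distribution would presumably serve as the distillation target in place of the raw teacher distribution.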
Collection: llm/modelopt

Changelog
- `MFTLoss` class implementing Minifinetuning distillation loss based on https://arxiv.org/abs/2506.15702
- `_prepare_corrected_distributions` method for threshold-based teacher probability correction
- `DistillationConfig` dataclass extended with `use_mft` and `mft_threshold` parameters
- Validation of `mft_threshold` in the `__post_init__` method
- `load_distillation_config` function in `utils.py` updated for proper integration

Usage
You can enable MFT loss in your distillation configuration:
In your distillation YAML config file:
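For example, a hedged sketch (only the `use_mft` and `mft_threshold` keys are taken from this PR's changelog; the file layout and any surrounding keys are assumptions):

```yaml
# Hypothetical snippet: key names from the changelog above; surrounding structure assumed.
use_mft: true        # switch from the standard KL-divergence loss to the MFT loss
mft_threshold: 0.5   # teacher-probability correction threshold, must be between 0 and 1
```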
Or, programmatically:
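For example (a sketch: `DistillationConfig`, `use_mft`, and `mft_threshold` are named in the changelog above, but the import path is an assumption and may differ from the actual module layout):

```python
# Hypothetical usage; the import path below is assumed and may not match NeMo's layout.
from nemo.collections.llm.modelopt import DistillationConfig  # assumed path

cfg = DistillationConfig(
    use_mft=True,        # select the MFT loss instead of standard KL divergence
    mft_threshold=0.5,   # validated in __post_init__ to lie between 0 and 1
)
```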
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
merge_requests/2785. No ModelOpt tests (except for one quantization check) are currently present in the NeMo Framework tests, so I did not bring the tests here.

PR Type: