use dynamic finetuning with chunked cross entropy #3057
base: main
Conversation
📝 Walkthrough
Introduces a `use_dynamic_finetuning` configuration option and threads it into the chunked cross-entropy loss patch. The loss computation is updated to optionally scale per-token losses by their probabilities when DFT is enabled, a schema field and a validation rule are added to enforce that DFT requires chunked cross-entropy, and `PatchManager` is adjusted to pass the new flag.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
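As a quick orientation to the new knob described in the walkthrough, a minimal sketch of the configuration surface follows, shown as a parsed config dict. Only the two keys named in this PR are grounded; everything else about the config file layout is an assumption.

```python
# Hypothetical parsed axolotl config enabling DFT. Both key names appear in
# this PR; their exact YAML spelling and surroundings are assumed.
cfg = {
    "chunked_cross_entropy": True,   # prerequisite: DFT lives in the chunked CE patch
    "use_dynamic_finetuning": True,  # the new flag introduced by this PR
}
```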
Actionable comments posted: 0
🧹 Nitpick comments (1)
src/axolotl/utils/schemas/validation.py (1)
393-403: Simplify nested if statements. The nested `if` statements can be combined for better readability.
Apply this diff to simplify the code:

```diff
-    @model_validator(mode="before")
-    @classmethod
-    def check_ao_optim_fsdp2_offload(cls, data):
-        if data.get("fsdp_config") and data.get("fsdp_config", {}).get(
-            "offload_params"
-        ):
-            if data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]:
-                raise ValueError(
-                    "low bit ao optimizers is not supported with FSDP2 w/ offload_params."
-                )
-        return data
+    @model_validator(mode="before")
+    @classmethod
+    def check_ao_optim_fsdp2_offload(cls, data):
+        if (
+            data.get("fsdp_config")
+            and data.get("fsdp_config", {}).get("offload_params")
+            and data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]
+        ):
+            raise ValueError(
+                "low bit ao optimizers is not supported with FSDP2 w/ offload_params."
+            )
+        return data
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/axolotl/loaders/patch_manager.py (1 hunks)
- src/axolotl/monkeypatch/loss/chunked.py (4 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (3 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/loaders/patch_manager.py (3)
- src/axolotl/monkeypatch/loss/chunked.py (1)
  - patch_chunked_ce_loss_fn (157-168)
- src/axolotl/integrations/base.py (2)
  - cfg (352-353)
  - cfg (356-357)
- tests/test_exact_deduplication.py (1)
  - cfg (201-216)
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py
- 396-399: Use a single if statement instead of nested if statements; combine the if statements using and (SIM102)
- 534-535: Use a single if statement instead of nested if statements (SIM102)
⏰ Context from checks skipped due to timeout of 90000ms (9). You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms).
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
- GitHub Check: preview
🔇 Additional comments (9)
src/axolotl/utils/schemas/config.py (1)
587-592: LGTM! The field is properly added with appropriate documentation.
src/axolotl/utils/schemas/validation.py (2)
528-540: LGTM! The validation correctly enforces the constraint that dynamic fine-tuning requires chunked cross-entropy. The error message is clear and informative.
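For illustration, a minimal sketch of what such a check could look like. The method name is hypothetical; only the two config keys and the `model_validator` pattern visible elsewhere on this page are grounded, and the real mixin contains more checks.

```python
from pydantic import model_validator


class CELossValidationMixin:
    """Sketch only; not the actual body of the mixin in validation.py."""

    @model_validator(mode="before")
    @classmethod
    def check_dft_requires_chunked_ce(cls, data):  # hypothetical name
        # DFT is implemented inside the chunked cross-entropy patch, so it
        # cannot be enabled on its own.
        if data.get("use_dynamic_finetuning") and not data.get("chunked_cross_entropy"):
            raise ValueError(
                "use_dynamic_finetuning requires chunked_cross_entropy to be enabled."
            )
        return data
```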
1351-1355: LGTM! The `CELossValidationMixin` is properly integrated into the `ValidationMixin` inheritance chain.
src/axolotl/monkeypatch/loss/chunked.py (5)
19-28: LGTM! The parameter addition is well-structured with proper default values that maintain backward compatibility.
30-62: Excellent implementation of Dynamic Fine-Tuning (DFT)! The DFT implementation properly weights per-token losses by token probabilities, following the paper's approach. The gradient-free computation and masking logic are correctly implemented.
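As a rough illustration of that weighting, here is a minimal self-contained sketch. The function name and signature are made up, and the real implementation splits the logits into chunks and handles reduction elsewhere; this only shows the probability-scaling idea.

```python
import torch
import torch.nn.functional as F


def dft_ce_loss(logits, labels, ignore_index=-100):
    """Per-token CE scaled by the (gradient-free) target-token probability."""
    logits = logits.float().view(-1, logits.size(-1))
    labels = labels.view(-1)

    # Standard per-token cross-entropy; ignored positions come back as 0.
    per_token = F.cross_entropy(
        logits, labels, ignore_index=ignore_index, reduction="none"
    )

    # DFT weight: p(target) = exp(-CE), detached so it only rescales gradients.
    weight = torch.exp(-per_token).detach()

    # Mask out ignored positions; normalization is left to the caller,
    # as the real chunked patch handles reduction itself.
    mask = labels != ignore_index
    return (weight * per_token)[mask].sum()
```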
100-107: LGTM! The function signature properly threads the `use_dft` parameter through to the loss construction.
110-114: LGTM! The `use_dft` parameter is correctly passed through the function call chain.
157-168: LGTM! The patching function correctly updates the transformers module with the DFT-enabled loss function.
src/axolotl/loaders/patch_manager.py (1)
104-115: LGTM! The patch application correctly threads the `use_dynamic_finetuning` configuration through to the chunked cross-entropy loss patches, handling both the case where a chunk count is specified and the case where it is not.
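In spirit, the PatchManager side does something like the following sketch. `patch_chunked_ce_loss_fn` and the `use_dft` keyword appear in this PR, but the chunk-count attribute name and the rest of the glue are assumptions, not the actual patch_manager.py code.

```python
# Hypothetical glue code mirroring what PatchManager does in this PR.
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn


def apply_loss_patches(cfg):
    use_dft = bool(getattr(cfg, "use_dynamic_finetuning", False))
    num_chunks = getattr(cfg, "chunked_cross_entropy_num_chunks", None)  # assumed name

    if num_chunks is not None:
        patch_chunked_ce_loss_fn(num_chunks=num_chunks, use_dft=use_dft)
    else:
        patch_chunked_ce_loss_fn(use_dft=use_dft)
```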
Codecov Report: ❌ Patch coverage is …
Description
paper: https://arxiv.org/abs/2508.05629
Adds support for DFT, which requires `chunked_cross_entropy` to be enabled, since that is where the DFT weighting is implemented.
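For reference, the reweighting amounts to the following, paraphrased from the paper and the description above rather than copied from this PR's code, with sg denoting stop-gradient:

$$\mathcal{L}_{\mathrm{DFT}} \;=\; \sum_{i} \operatorname{sg}\big(p_\theta(y_i \mid y_{<i}, x)\big)\,\ell_i, \qquad \ell_i = -\log p_\theta(y_i \mid y_{<i}, x).$$

Because the weight is detached, the gradient is the standard cross-entropy gradient rescaled per token by the model's confidence in the target token.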
Summary by CodeRabbit

New Features
- Added a `use_dynamic_finetuning` option that scales per-token losses by token probabilities in the chunked cross-entropy loss.

Validation
- Added a check enforcing that `use_dynamic_finetuning` requires chunked cross-entropy to be enabled.