
Conversation

@winglian (Collaborator) commented Aug 12, 2025

Description

paper: https://arxiv.org/abs/2508.05629

Supports DFT (Dynamic Fine-Tuning). It requires chunked_cross_entropy, since the DFT loss scaling is implemented in the chunked cross-entropy patch:

```yaml
use_dynamic_finetuning: true
chunked_cross_entropy: true
```
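
Roughly speaking (a hedged paraphrase of the paper's idea, not its exact notation), DFT keeps the usual SFT cross-entropy but reweights each token's loss by the model's own, detached probability of the target token:

```math
\mathcal{L}_{\mathrm{DFT}} \;=\; -\sum_{t}\, m_t \,\operatorname{sg}\!\left[p_\theta\!\left(y_t \mid y_{<t}\right)\right]\,\log p_\theta\!\left(y_t \mid y_{<t}\right)
```

where m_t is the label mask and sg[·] denotes stop-gradient; normalization over unmasked tokens follows whatever the chunked CE loss already does.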

Summary by CodeRabbit

  • New Features

    • Added an optional setting to enable dynamic fine-tuning that adapts loss weighting during training, integrated with chunked cross-entropy.
  • Validation

    • Introduced pre-run checks requiring chunked cross-entropy when dynamic fine-tuning is enabled, with clear error messages for misconfiguration.
    • Added a guard preventing incompatible combinations of parameter offload and specific 8-bit/4-bit optimizers in FSDP2.

@coderabbitai coderabbitai bot (Contributor) commented Aug 12, 2025

📝 Walkthrough

Walkthrough

Introduces a use_dynamic_finetuning configuration and threads it into the chunked cross-entropy loss patch. Updates loss computation to optionally scale per-token losses by probabilities when DFT is enabled, adds schema field and validation enforcing DFT requires chunked CE, and adjusts PatchManager to pass the new flag.

Changes

| Cohort / File(s) | Summary of changes |
| --- | --- |
| Patch manager integration: src/axolotl/loaders/patch_manager.py | Passes use_dft flag from config (use_dynamic_finetuning) into patch_chunked_ce_loss_fn, with or without num_chunks. |
| Chunked CE loss and patch plumbing: src/axolotl/monkeypatch/loss/chunked.py | Adds use_dft across CEWithChunkedOutputLoss, builders, and patch function. Changes loss to per-token (reduction="none"); in DFT mode, scales by token probability before summation. |
| Config schema: src/axolotl/utils/schemas/config.py | Adds optional boolean use_dynamic_finetuning to AxolotlInputConfig with description. |
| Validation mixins: src/axolotl/utils/schemas/validation.py | Adds CELossValidationMixin to enforce chunked CE when DFT is true. Adds optimizer/FSDP offload validation. ValidationMixin now inherits the new mixin. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • chunked cross entropy loss #2625: Extends the same chunked CE patch by threading the use_dft flag through identical call sites and PatchManager, directly connected to this change.

Suggested labels

ready to merge

Suggested reviewers

  • NanoCode012


@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
src/axolotl/utils/schemas/validation.py (1)

393-403: Simplify nested if statements.

The nested if statements can be combined for better readability.

Apply this diff to simplify the code:

-    @model_validator(mode="before")
-    @classmethod
-    def check_ao_optim_fsdp2_offload(cls, data):
-        if data.get("fsdp_config") and data.get("fsdp_config", {}).get(
-            "offload_params"
-        ):
-            if data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]:
-                raise ValueError(
-                    "low bit ao optimizers is not supported with FSDP2 w/ offload_params."
-                )
-        return data
+    @model_validator(mode="before")
+    @classmethod
+    def check_ao_optim_fsdp2_offload(cls, data):
+        if (
+            data.get("fsdp_config")
+            and data.get("fsdp_config", {}).get("offload_params")
+            and data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]
+        ):
+            raise ValueError(
+                "low bit ao optimizers is not supported with FSDP2 w/ offload_params."
+            )
+        return data
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3d45620 and a4b6ff5.

📒 Files selected for processing (4)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/loss/chunked.py (4 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (3 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
src/axolotl/loaders/patch_manager.py (3)
src/axolotl/monkeypatch/loss/chunked.py (1)
  • patch_chunked_ce_loss_fn (157-168)
src/axolotl/integrations/base.py (2)
  • cfg (352-353)
  • cfg (356-357)
tests/test_exact_deduplication.py (1)
  • cfg (201-216)
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py

396-399: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


534-535: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
  • GitHub Check: preview
🔇 Additional comments (9)
src/axolotl/utils/schemas/config.py (1)

587-592: LGTM!

The field is properly added with appropriate documentation.

src/axolotl/utils/schemas/validation.py (2)

528-540: LGTM!

The validation correctly enforces the constraint that dynamic fine-tuning requires chunked cross-entropy. The error message is clear and informative.
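
For illustration only (the field names mirror the config keys above, but the validator body is a hypothetical sketch, not the code under review), such a check typically looks like:

```python
from pydantic import BaseModel, model_validator


class CELossValidationSketch(BaseModel):
    """Hypothetical sketch of a DFT/chunked-CE consistency check (not the reviewed code)."""

    use_dynamic_finetuning: bool | None = None
    chunked_cross_entropy: bool | None = None

    @model_validator(mode="before")
    @classmethod
    def check_dft_requires_chunked_ce(cls, data):
        # DFT scales per-token losses inside the chunked CE patch, so it cannot
        # run unless chunked_cross_entropy is also enabled.
        if data.get("use_dynamic_finetuning") and not data.get("chunked_cross_entropy"):
            raise ValueError(
                "use_dynamic_finetuning: true requires chunked_cross_entropy: true"
            )
        return data
```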


1351-1355: LGTM!

The CELossValidationMixin is properly integrated into the ValidationMixin inheritance chain.

src/axolotl/monkeypatch/loss/chunked.py (5)

19-28: LGTM!

The parameter addition is well-structured with proper default values that maintain backward compatibility.


30-62: Excellent implementation of Dynamic Fine-Tuning (DFT)!

The DFT implementation properly weights per-token losses by token probabilities, following the paper's approach. The gradient-free computation and masking logic are correctly implemented.
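
A minimal standalone sketch of that weighting (illustrative only; this is not the patched CEWithChunkedOutputLoss code, and the normalization over unmasked tokens is an assumption):

```python
import torch
import torch.nn.functional as F


def dft_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100
) -> torch.Tensor:
    """Per-token CE scaled by the detached probability of the target token."""
    vocab = logits.size(-1)
    flat_logits = logits.view(-1, vocab)
    flat_labels = labels.view(-1)

    # Per-token loss (reduction="none") so each token can be reweighted individually.
    per_token = F.cross_entropy(
        flat_logits, flat_labels, reduction="none", ignore_index=ignore_index
    )
    mask = flat_labels != ignore_index

    # Probability of each target token, computed without gradient flow (stop-gradient).
    with torch.no_grad():
        probs = torch.softmax(flat_logits, dim=-1)
        target_probs = probs.gather(-1, flat_labels.clamp_min(0).unsqueeze(-1)).squeeze(-1)

    weighted = per_token * target_probs * mask
    return weighted.sum() / mask.sum().clamp_min(1)
```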


100-107: LGTM!

The function signature properly threads the use_dft parameter through to the loss construction.


110-114: LGTM!

The use_dft parameter is correctly passed through the function call chain.


157-168: LGTM!

The patching function correctly updates the transformers module with the DFT-enabled loss function.

src/axolotl/loaders/patch_manager.py (1)

104-115: LGTM!

The patch application correctly threads the use_dynamic_finetuning configuration through to the chunked cross-entropy loss patches, handling both cases where chunk count is specified or not.
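
Purely as a hypothetical illustration of that kind of conditional plumbing (the config key and argument names below are placeholders, not the real signatures):

```python
def apply_chunked_ce_patches(cfg: dict, patch_fn) -> None:
    # Hypothetical helper: always forward the DFT flag, and pass a chunk count
    # only when one is configured.
    kwargs = {"use_dft": bool(cfg.get("use_dynamic_finetuning"))}
    num_chunks = cfg.get("chunked_cross_entropy_num_chunks")  # placeholder key
    if num_chunks:
        kwargs["num_chunks"] = num_chunks
    patch_fn(**kwargs)
```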

@codecov codecov bot commented Aug 12, 2025

Codecov Report

❌ Patch coverage is 55.55556% with 16 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/axolotl/monkeypatch/loss/chunked.py | 38.88% | 11 Missing ⚠️ |
| src/axolotl/utils/schemas/validation.py | 80.00% | 3 Missing ⚠️ |
| src/axolotl/loaders/patch_manager.py | 0.00% | 2 Missing ⚠️ |
