From 625b6e826fb1ed8b113dc11c49dc515f814b41f3 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 11 Aug 2025 18:48:32 -0400
Subject: [PATCH 1/2] use dynamic finetuning with chunked cross entropy

---
 src/axolotl/loaders/patch_manager.py    |  7 +++-
 src/axolotl/monkeypatch/loss/chunked.py | 52 ++++++++++++++++++++-----
 src/axolotl/utils/schemas/config.py     |  7 ++++
 3 files changed, 55 insertions(+), 11 deletions(-)

diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index f1ca3c7259..7373db491b 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -106,9 +106,12 @@ def _apply_chunked_cross_entropy_patch(self):
         from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
 
         if self.cfg.chunked_cross_entropy_num_chunks:
-            patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
+            patch_chunked_ce_loss_fn(
+                self.cfg.chunked_cross_entropy_num_chunks,
+                use_dft=self.cfg.use_dynamic_finetuning,
+            )
         else:
-            patch_chunked_ce_loss_fn()
+            patch_chunked_ce_loss_fn(use_dft=self.cfg.use_dynamic_finetuning)
 
     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
diff --git a/src/axolotl/monkeypatch/loss/chunked.py b/src/axolotl/monkeypatch/loss/chunked.py
index 0a9d0de82c..8131ffba28 100644
--- a/src/axolotl/monkeypatch/loss/chunked.py
+++ b/src/axolotl/monkeypatch/loss/chunked.py
@@ -16,10 +16,16 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
     For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
     """
 
-    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
+    def __init__(
+        self,
+        num_output_chunks: int = 8,
+        ignore_index: int = -100,
+        use_dft: bool = False,
+    ):
         super().__init__()
         self.num_output_chunks = num_output_chunks
         self.ignore_index = ignore_index
+        self.use_dft = use_dft
 
     def compute_cross_entropy(
         self,
@@ -30,10 +36,30 @@ def compute_cross_entropy(
         """
         Upcast logits to fp32 and compute cross entropy loss.
""" - return F.cross_entropy( - logits.float(), labels, ignore_index=self.ignore_index, reduction="sum" + ce_loss = F.cross_entropy( + logits.float(), labels, ignore_index=self.ignore_index, reduction="none" ) + if self.use_dft: + # Compute probabilities and gather the ones corresponding to labels + with torch.no_grad(): # Stop gradient + probs = torch.softmax(logits.float(), dim=-1) + # Create mask for valid tokens (not ignore_index) + valid_mask = labels != self.ignore_index + # Gather probabilities for the correct tokens + label_probs = probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1) + # Apply mask to only scale valid tokens + label_probs = label_probs * valid_mask + # Avoid multiplication by 0 for ignored tokens + label_probs = torch.where( + valid_mask, label_probs, torch.ones_like(label_probs) + ) + + # Scale the loss by the probability (DFT) + ce_loss = ce_loss * label_probs + + return ce_loss.sum() + def forward( self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum" ) -> torch.Tensor: @@ -71,16 +97,20 @@ def forward( return total_loss / total_elements -def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): - loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index) +def _build_chunked_ce_loss_fn( + num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False +): + loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index, use_dft) loss_fn_ce.compute_cross_entropy = torch.compile( loss_fn_ce.compute_cross_entropy, backend="inductor" ) return loss_fn_ce -def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100): - loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index) +def get_causal_lm_loss( + num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False +): + loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index, use_dft) def chunked_fix_cross_entropy( source, @@ -124,10 +154,14 @@ def for_causal_lm_chunked_loss( return for_causal_lm_chunked_loss -def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100): +def patch_chunked_ce_loss_fn( + num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False +): import transformers.loss.loss_utils - for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index) + for_causal_lm_chunked_loss = get_causal_lm_loss( + num_output_chunks, ignore_index, use_dft + ) transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = ( for_causal_lm_chunked_loss diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 21e99c0483..41a365c289 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -584,6 +584,13 @@ class AxolotlInputConfig( }, ) + use_dynamic_finetuning: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use dynamic fine-tuning for scaled SFT gradients." 
+ }, + ) + chunked_cross_entropy: bool | None = Field( default=None, json_schema_extra={ From a4b6ff548cd81d30dea5a35b4dd74d09de694b18 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 11 Aug 2025 21:28:07 -0400 Subject: [PATCH 2/2] add validation for DFT --- src/axolotl/utils/schemas/validation.py | 27 +++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 72991c9470..feb60514a4 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -390,6 +390,18 @@ def check_fp8_config(cls, data): return data + @model_validator(mode="before") + @classmethod + def check_ao_optim_fsdp2_offload(cls, data): + if data.get("fsdp_config") and data.get("fsdp_config", {}).get( + "offload_params" + ): + if data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]: + raise ValueError( + "low bit ao optimizers is not supported with FSDP2 w/ offload_params." + ) + return data + @model_validator(mode="before") @classmethod def check_use_reentrant_mismatch(cls, data): @@ -513,6 +525,20 @@ def pretrain_with_tps(cls, data): return data +class CELossValidationMixin: + """Validation methods related to CE loss configuration.""" + + @model_validator(mode="before") + @classmethod + def check_dft_loss_fn(cls, data): + if data.get("use_dynamic_finetuning"): + if not data.get("chunked_cross_entropy"): + raise ValueError( + "`use_dynamic_finetuning` requires `chunked_cross_entropy`" + ) + return data + + class LoRAValidationMixin: """Validation methods related to LoRA/QLoRA configuration.""" @@ -1326,6 +1352,7 @@ class ValidationMixin( DatasetValidationMixin, AttentionValidationMixin, TrainingValidationMixin, + CELossValidationMixin, LoRAValidationMixin, RLValidationMixin, OptimizationValidationMixin,