7 changes: 5 additions & 2 deletions src/axolotl/loaders/patch_manager.py
@@ -106,9 +106,12 @@ def _apply_chunked_cross_entropy_patch(self):
         from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
 
         if self.cfg.chunked_cross_entropy_num_chunks:
-            patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
+            patch_chunked_ce_loss_fn(
+                self.cfg.chunked_cross_entropy_num_chunks,
+                use_dft=self.cfg.use_dynamic_finetuning,
+            )
         else:
-            patch_chunked_ce_loss_fn()
+            patch_chunked_ce_loss_fn(use_dft=self.cfg.use_dynamic_finetuning)
 
     def _apply_fsdp_patches(self):
         """Apply patches for FSDP configurations."""
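For orientation, a minimal sketch (outside the diff) of what the two branches above resolve to. The chunk count of 4 is illustrative; when no chunk count is configured, the num_output_chunks default of 8 defined in chunked.py applies.

from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn

# cfg.chunked_cross_entropy_num_chunks = 4, cfg.use_dynamic_finetuning = True
patch_chunked_ce_loss_fn(4, use_dft=True)

# no chunk count configured: num_output_chunks falls back to its default of 8
patch_chunked_ce_loss_fn(use_dft=True)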
52 changes: 43 additions & 9 deletions src/axolotl/monkeypatch/loss/chunked.py
@@ -16,10 +16,16 @@ class CEWithChunkedOutputLoss(torch.nn.Module):
     For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
     """
 
-    def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
+    def __init__(
+        self,
+        num_output_chunks: int = 8,
+        ignore_index: int = -100,
+        use_dft: bool = False,
+    ):
         super().__init__()
         self.num_output_chunks = num_output_chunks
         self.ignore_index = ignore_index
+        self.use_dft = use_dft
 
     def compute_cross_entropy(
         self,
@@ -30,10 +36,30 @@ def compute_cross_entropy(
         """
         Upcast logits to fp32 and compute cross entropy loss.
         """
-        return F.cross_entropy(
-            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
+        ce_loss = F.cross_entropy(
+            logits.float(), labels, ignore_index=self.ignore_index, reduction="none"
         )
 
+        if self.use_dft:
+            # Compute probabilities and gather the ones corresponding to labels
+            with torch.no_grad():  # Stop gradient
+                probs = torch.softmax(logits.float(), dim=-1)
+                # Create mask for valid tokens (not ignore_index)
+                valid_mask = labels != self.ignore_index
+                # Gather probabilities for the correct tokens
+                label_probs = probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
+                # Apply mask to only scale valid tokens
+                label_probs = label_probs * valid_mask
+                # Avoid multiplication by 0 for ignored tokens
+                label_probs = torch.where(
+                    valid_mask, label_probs, torch.ones_like(label_probs)
+                )
+
+            # Scale the loss by the probability (DFT)
+            ce_loss = ce_loss * label_probs
+
+        return ce_loss.sum()
+
     def forward(
         self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
     ) -> torch.Tensor:
@@ -71,16 +97,20 @@ def forward(
         return total_loss / total_elements
 
 
-def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
-    loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
+def _build_chunked_ce_loss_fn(
+    num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
+):
+    loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index, use_dft)
     loss_fn_ce.compute_cross_entropy = torch.compile(
         loss_fn_ce.compute_cross_entropy, backend="inductor"
     )
     return loss_fn_ce
 
 
-def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
-    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
+def get_causal_lm_loss(
+    num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
+):
+    loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index, use_dft)
 
     def chunked_fix_cross_entropy(
         source,
@@ -124,10 +154,14 @@ def for_causal_lm_chunked_loss(
     return for_causal_lm_chunked_loss
 
 
-def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
+def patch_chunked_ce_loss_fn(
+    num_output_chunks: int = 8, ignore_index: int = -100, use_dft: bool = False
+):
     import transformers.loss.loss_utils
 
-    for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
+    for_causal_lm_chunked_loss = get_causal_lm_loss(
+        num_output_chunks, ignore_index, use_dft
+    )
     transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
     transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
         for_causal_lm_chunked_loss
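A minimal standalone sketch (outside the diff) of the DFT scaling that compute_cross_entropy applies when use_dft is enabled: the per-token cross entropy is multiplied by the detached probability of the label token, so low-probability tokens are down-weighted while the scaling factor itself carries no gradient. The helper name and the clamp on ignored positions are illustrative, not taken from the PR.

import torch
import torch.nn.functional as F

def dft_cross_entropy(logits, labels, ignore_index=-100):
    # Per-token CE, upcast to fp32, no reduction yet
    ce = F.cross_entropy(
        logits.float(), labels, ignore_index=ignore_index, reduction="none"
    )
    with torch.no_grad():  # the scaling factor carries no gradient
        probs = torch.softmax(logits.float(), dim=-1)
        valid = labels != ignore_index
        # clamp keeps gather's indices non-negative; masked positions are
        # overwritten with 1.0 below, so the clamp never affects the result
        p_label = probs.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
        p_label = torch.where(valid, p_label, torch.ones_like(p_label))
    # Down-weight tokens the model assigns low probability to (DFT)
    return (ce * p_label).sum()

logits = torch.randn(5, 32, requires_grad=True)  # (tokens, vocab)
labels = torch.tensor([3, 7, -100, 0, 12])       # -100 positions are ignored
loss = dft_cross_entropy(logits, labels)
loss.backward()                                  # gradients flow only through ce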
7 changes: 7 additions & 0 deletions src/axolotl/utils/schemas/config.py
@@ -584,6 +584,13 @@ class AxolotlInputConfig(
         },
     )
 
+    use_dynamic_finetuning: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "Whether to use dynamic fine-tuning for scaled SFT gradients."
+        },
+    )
+
     chunked_cross_entropy: bool | None = Field(
         default=None,
         json_schema_extra={
27 changes: 27 additions & 0 deletions src/axolotl/utils/schemas/validation.py
@@ -390,6 +390,18 @@ def check_fp8_config(cls, data):
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_ao_optim_fsdp2_offload(cls, data):
+        if data.get("fsdp_config") and data.get("fsdp_config", {}).get(
+            "offload_params"
+        ):
+            if data.get("optimizer") in ["adamw_torch_8bit", "adamw_torch_4bit"]:
+                raise ValueError(
+                    "low bit ao optimizers are not supported with FSDP2 w/ offload_params."
+                )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_use_reentrant_mismatch(cls, data):
@@ -513,6 +525,20 @@ def pretrain_with_tps(cls, data):
         return data
 
 
+class CELossValidationMixin:
+    """Validation methods related to CE loss configuration."""
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_dft_loss_fn(cls, data):
+        if data.get("use_dynamic_finetuning"):
+            if not data.get("chunked_cross_entropy"):
+                raise ValueError(
+                    "`use_dynamic_finetuning` requires `chunked_cross_entropy`"
+                )
+        return data
+
+
 class LoRAValidationMixin:
     """Validation methods related to LoRA/QLoRA configuration."""
 
Expand Down Expand Up @@ -1326,6 +1352,7 @@ class ValidationMixin(
DatasetValidationMixin,
AttentionValidationMixin,
TrainingValidationMixin,
CELossValidationMixin,
LoRAValidationMixin,
RLValidationMixin,
OptimizationValidationMixin,
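Finally, a hedged sketch (outside the diff) of how the new check_dft_loss_fn validator behaves on the raw config mapping that "before"-mode validators receive; the standalone function below mirrors the mixin method rather than reproducing it.

def check_dft_loss_fn(data: dict) -> dict:
    # mirrors CELossValidationMixin.check_dft_loss_fn above
    if data.get("use_dynamic_finetuning") and not data.get("chunked_cross_entropy"):
        raise ValueError("`use_dynamic_finetuning` requires `chunked_cross_entropy`")
    return data

check_dft_loss_fn({"chunked_cross_entropy": True, "use_dynamic_finetuning": True})  # passes
# check_dft_loss_fn({"use_dynamic_finetuning": True})  # would raise ValueError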