Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/axolotl/integrations/liger/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

# FLCE-specific
liger_use_token_scaling: true
```

## Supported Models
Expand Down
11 changes: 10 additions & 1 deletion src/axolotl/integrations/liger/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Module for handling LIGER input arguments.
"""

from pydantic import BaseModel, model_validator
from pydantic import BaseModel, Field, model_validator

from axolotl.utils.logging import get_logger

Expand All @@ -35,6 +35,15 @@ class LigerArgs(BaseModel):
liger_glu_activation: bool | None = None
liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: bool | None = None
liger_use_token_scaling: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables use_token_scaling in fused_linear_cross_entropy. "
"When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
)
},
)

@model_validator(mode="before")
@classmethod
Expand Down
27 changes: 27 additions & 0 deletions src/axolotl/integrations/liger/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,33 @@ def pre_model_load(self, cfg):
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)

if cfg.liger_fused_linear_cross_entropy and cfg.liger_use_token_scaling:
# Patch FLCE to set token_scaling=True for function and class API
from liger_kernel.transformers import functional
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)

old_liger_fused_linear_cross_entropy = (
functional.liger_fused_linear_cross_entropy
)

def patched_liger_fused_linear_cross_entropy(*args, **kwargs):
    """Wrapper for the functional FLCE API that forces ``use_token_scaling=True``.

    Closes over ``old_liger_fused_linear_cross_entropy`` (the original
    ``functional.liger_fused_linear_cross_entropy``) and delegates all
    arguments to it, overriding only the ``use_token_scaling`` keyword.
    """
    # Force the keyword regardless of what the caller passed.
    # NOTE(review): if a caller ever supplies use_token_scaling positionally,
    # this would raise "got multiple values for argument" — presumably callers
    # only pass it by keyword; confirm against liger-kernel's call sites.
    kwargs["use_token_scaling"] = True
    return old_liger_fused_linear_cross_entropy(*args, **kwargs)

functional.liger_fused_linear_cross_entropy = (
patched_liger_fused_linear_cross_entropy
)

old_init = LigerFusedLinearCrossEntropyLoss.__init__

def patched_init(self, *args, **kwargs):
    """Replacement ``__init__`` for ``LigerFusedLinearCrossEntropyLoss``.

    Closes over ``old_init`` (the class's original ``__init__``) and calls it
    with all original arguments, overriding only ``use_token_scaling=True``
    so every class-API instance gets token scaling enabled.
    """
    # Force the keyword regardless of what the caller passed.
    # NOTE(review): a positional use_token_scaling argument would collide with
    # this keyword override — assumed to always be passed by keyword; verify.
    kwargs["use_token_scaling"] = True
    # Propagate old_init's return value (conventionally None for __init__).
    return old_init(self, *args, **kwargs)

LigerFusedLinearCrossEntropyLoss.__init__ = patched_init

if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
Expand Down
Loading