3 changes: 3 additions & 0 deletions src/axolotl/integrations/liger/README.md
@@ -18,6 +18,9 @@ liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

# FLCE-specific
liger_use_token_scaling: true
```

## Supported Models
23 changes: 22 additions & 1 deletion src/axolotl/integrations/liger/args.py
@@ -16,7 +16,7 @@
Module for handling LIGER input arguments.
"""

from pydantic import BaseModel, model_validator
from pydantic import BaseModel, Field, model_validator

from axolotl.utils.logging import get_logger

@@ -35,6 +35,15 @@ class LigerArgs(BaseModel):
liger_glu_activation: bool | None = None
liger_cross_entropy: bool | None = None
liger_fused_linear_cross_entropy: bool | None = None
liger_use_token_scaling: bool | None = Field(
default=None,
json_schema_extra={
"description": (
"Enables use_token_scaling in fused_linear_cross_entropy. "
"When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
)
},
)

@model_validator(mode="before")
@classmethod
@@ -75,6 +84,18 @@ def check_liger_rms_norm_tensor_parallel(cls, data):
)
return data

@model_validator(mode="before")
@classmethod
def check_liger_use_token_scaling_flce(cls, data):
if data.get("liger_use_token_scaling") and not data.get(
"liger_fused_linear_cross_entropy"
):
raise ValueError(
"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled."
)

return data

@model_validator(mode="after")
def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
# TODO @SalmanMohammadi this is a larger fix - investigate
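For intuition, here is a minimal unfused PyTorch sketch of the scaling the field description above refers to: each token's cross-entropy loss is weighted by the probability the model assigns to its target token, detached so the weight itself carries no gradient. This is a reference illustration only, not Liger's fused kernel, and `token_scaled_ce` is a hypothetical name:

```python
import torch
import torch.nn.functional as F

def token_scaled_ce(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (num_tokens, vocab_size), labels: (num_tokens,)
    per_token = F.cross_entropy(logits, labels, reduction="none")
    # Probability assigned to each target token, detached so the scaling
    # factor contributes no gradient of its own
    target_prob = F.softmax(logits, dim=-1).gather(1, labels.unsqueeze(1))
    scale = target_prob.squeeze(1).detach()
    return (per_token * scale).mean()
```

Relative to plain cross-entropy, this down-weights the tokens the model is least confident about.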
27 changes: 27 additions & 0 deletions src/axolotl/integrations/liger/plugin.py
@@ -48,6 +48,33 @@ def pre_model_load(self, cfg):
"Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
)

if cfg.liger_use_token_scaling:
# Patch FLCE to set token_scaling=True for function and class API
from liger_kernel.transformers import functional
from liger_kernel.transformers.fused_linear_cross_entropy import (
LigerFusedLinearCrossEntropyLoss,
)

old_liger_fused_linear_cross_entropy = (
functional.liger_fused_linear_cross_entropy
)

def patched_liger_fused_linear_cross_entropy(*args, **kwargs):
kwargs["use_token_scaling"] = True
return old_liger_fused_linear_cross_entropy(*args, **kwargs)

functional.liger_fused_linear_cross_entropy = (
patched_liger_fused_linear_cross_entropy
)

old_init = LigerFusedLinearCrossEntropyLoss.__init__

def patched_init(self, *args, **kwargs):
kwargs["use_token_scaling"] = True
return old_init(self, *args, **kwargs)

LigerFusedLinearCrossEntropyLoss.__init__ = patched_init

if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)
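Since callers may go through either the functional API or the loss class, the patch wraps both entry points. A self-contained sketch of the same kwarg-injection pattern on a stand-in class (the `Loss` class here is hypothetical, not Liger's):

```python
class Loss:
    # Stand-in for a loss class whose default we want to force on
    def __init__(self, *, use_token_scaling: bool = False):
        self.use_token_scaling = use_token_scaling

_old_init = Loss.__init__

def _patched_init(self, *args, **kwargs):
    kwargs["use_token_scaling"] = True  # injected regardless of caller
    return _old_init(self, *args, **kwargs)

Loss.__init__ = _patched_init

assert Loss().use_token_scaling  # the default of False is now overridden
```

One caveat of kwarg injection: if the wrapped signature accepted the flag positionally and a caller passed it that way, the injected keyword would collide with it, so keyword(-only) arguments are the safe target for this pattern.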
8 changes: 7 additions & 1 deletion tests/e2e/integrations/test_liger.py
@@ -2,6 +2,7 @@
Simple end-to-end test for Liger integration
"""

import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -62,7 +63,11 @@ def test_llama_wo_flce(self, temp_dir):
check_model_output_exists(temp_dir, cfg)

@require_torch_2_4_1
def test_llama_w_flce(self, temp_dir):
@pytest.mark.parametrize(
"liger_use_token_scaling",
[True, False],
)
def test_llama_w_flce(self, temp_dir, liger_use_token_scaling):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -74,6 +79,7 @@ def test_llama_w_flce(self, temp_dir):
"liger_glu_activation": True,
"liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True,
"liger_use_token_scaling": liger_use_token_scaling,
"sequence_len": 1024,
"val_set_size": 0.05,
"special_tokens": {
16 changes: 16 additions & 0 deletions tests/integrations/test_liger.py
@@ -75,3 +75,19 @@ def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg):
):
prepare_plugins(test_cfg)
validate_config(test_cfg)

def test_use_token_scaling_requires_flce(self, minimal_liger_cfg):
test_cfg = DictDefault(
{
"liger_fused_linear_cross_entropy": False,
"liger_use_token_scaling": True,
}
| minimal_liger_cfg
)

with pytest.raises(
ValueError,
match=r"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` to be enabled.",
):
prepare_plugins(test_cfg)
validate_config(test_cfg)