From 4aafd5f6a11d1626b77efe5e4731d3675a2b982a Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 2 Sep 2025 11:21:52 +0700 Subject: [PATCH 1/4] feat: add arg to enable dft in liger --- src/axolotl/integrations/liger/README.md | 3 +++ src/axolotl/integrations/liger/args.py | 11 +++++++++- src/axolotl/integrations/liger/plugin.py | 27 ++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/axolotl/integrations/liger/README.md b/src/axolotl/integrations/liger/README.md index c5cce8282b..3a2d4bd04b 100644 --- a/src/axolotl/integrations/liger/README.md +++ b/src/axolotl/integrations/liger/README.md @@ -18,6 +18,9 @@ liger_rms_norm: true liger_glu_activation: true liger_layer_norm: true liger_fused_linear_cross_entropy: true + +# FLCE-specific +liger_use_token_scaling: true ``` ## Supported Models diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py index d5bb10cfdd..4df74eb440 100644 --- a/src/axolotl/integrations/liger/args.py +++ b/src/axolotl/integrations/liger/args.py @@ -16,7 +16,7 @@ Module for handling LIGER input arguments. """ -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, Field, model_validator from axolotl.utils.logging import get_logger @@ -35,6 +35,15 @@ class LigerArgs(BaseModel): liger_glu_activation: bool | None = None liger_cross_entropy: bool | None = None liger_fused_linear_cross_entropy: bool | None = None + liger_use_token_scaling: bool | None = Field( + default=None, + json_schema_extra={ + "description": ( + "Enables use_token_scaling in fused_linear_cross_entropy. " + "When True, each token's loss is multiplied by its predicted probability (detached from gradients)." + ) + }, + ) @model_validator(mode="before") @classmethod diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py index 89f7c37b71..26fa703283 100644 --- a/src/axolotl/integrations/liger/plugin.py +++ b/src/axolotl/integrations/liger/plugin.py @@ -48,6 +48,33 @@ def pre_model_load(self, cfg): "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set." ) + if cfg.liger_fused_linear_cross_entropy and cfg.liger_use_token_scaling: + # Patch FLCE to set token_scaling=True for function and class API + from liger_kernel.transformers import functional + from liger_kernel.transformers.fused_linear_cross_entropy import ( + LigerFusedLinearCrossEntropyLoss, + ) + + old_liger_fused_linear_cross_entropy = ( + functional.liger_fused_linear_cross_entropy + ) + + def patched_liger_fused_linear_cross_entropy(*args, **kwargs): + kwargs["use_token_scaling"] = True + return old_liger_fused_linear_cross_entropy(*args, **kwargs) + + functional.liger_fused_linear_cross_entropy = ( + patched_liger_fused_linear_cross_entropy + ) + + old_init = LigerFusedLinearCrossEntropyLoss.__init__ + + def patched_init(self, *args, **kwargs): + kwargs["use_token_scaling"] = True + return old_init(self, *args, **kwargs) + + LigerFusedLinearCrossEntropyLoss.__init__ = patched_init + if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] liger_fn_sig = inspect.signature(apply_liger_fn) From 518c4c882b5db36de59c73924a36d53a3d9862ce Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 2 Sep 2025 11:38:54 +0700 Subject: [PATCH 2/4] feat: add tests use_token_scaling --- src/axolotl/integrations/liger/plugin.py | 7 ++++++- tests/e2e/integrations/test_liger.py | 8 +++++++- tests/integrations/test_liger.py | 16 ++++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py index 26fa703283..9ff12312a0 100644 --- a/src/axolotl/integrations/liger/plugin.py +++ b/src/axolotl/integrations/liger/plugin.py @@ -48,7 +48,12 @@ def pre_model_load(self, cfg): "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set." ) - if cfg.liger_fused_linear_cross_entropy and cfg.liger_use_token_scaling: + if cfg.liger_use_token_scaling: + if not cfg.liger_fused_linear_cross_entropy: + raise ValueError( + "`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled." + ) + # Patch FLCE to set token_scaling=True for function and class API from liger_kernel.transformers import functional from liger_kernel.transformers.fused_linear_cross_entropy import ( diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index 2859699633..55317151e5 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -2,6 +2,7 @@ Simple end-to-end test for Liger integration """ +import pytest from axolotl.common.datasets import load_datasets from axolotl.train import train from axolotl.utils.config import normalize_config, prepare_plugins, validate_config @@ -62,7 +63,11 @@ def test_llama_wo_flce(self, temp_dir): check_model_output_exists(temp_dir, cfg) @require_torch_2_4_1 - def test_llama_w_flce(self, temp_dir): + @pytest.mark.parametrize( + "liger_use_token_scaling", + [True, False], + ) + def test_llama_w_flce(self, temp_dir, liger_use_token_scaling): cfg = DictDefault( { "base_model": "HuggingFaceTB/SmolLM2-135M", @@ -74,6 +79,7 @@ def test_llama_w_flce(self, temp_dir): "liger_glu_activation": True, "liger_cross_entropy": False, "liger_fused_linear_cross_entropy": True, + "liger_use_token_scaling": liger_use_token_scaling, "sequence_len": 1024, "val_set_size": 0.05, "special_tokens": { diff --git a/tests/integrations/test_liger.py b/tests/integrations/test_liger.py index d7b171ec27..964ad86e2a 100644 --- a/tests/integrations/test_liger.py +++ b/tests/integrations/test_liger.py @@ -75,3 +75,19 @@ def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg): ): prepare_plugins(test_cfg) validate_config(test_cfg) + + def test_use_token_scaling_require_flce(self, minimal_liger_cfg): + test_cfg = DictDefault( + { + "liger_fused_linear_cross_entropy": False, + "liger_use_token_scaling": True, + } + | minimal_liger_cfg + ) + + with pytest.raises( + ValueError, + match=r".*`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled.*", + ): + prepare_plugins(test_cfg) + validate_config(test_cfg) From eeb17e7216a19bed98ac9acac0710a58c8018773 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 2 Sep 2025 12:01:10 +0700 Subject: [PATCH 3/4] fix: test --- tests/integrations/test_liger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integrations/test_liger.py b/tests/integrations/test_liger.py index 964ad86e2a..6865306c9d 100644 --- a/tests/integrations/test_liger.py +++ b/tests/integrations/test_liger.py @@ -87,7 +87,7 @@ def test_use_token_scaling_require_flce(self, minimal_liger_cfg): with pytest.raises( ValueError, - match=r".*`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled.*", + match=r"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled.", ): prepare_plugins(test_cfg) validate_config(test_cfg) From 0b2795f3786fbd4bc4ea4a0017d2b68de38039d0 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Tue, 2 Sep 2025 13:14:41 +0700 Subject: [PATCH 4/4] fix: move check to args --- src/axolotl/integrations/liger/args.py | 12 ++++++++++++ src/axolotl/integrations/liger/plugin.py | 5 ----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py index 4df74eb440..eb7a6c59be 100644 --- a/src/axolotl/integrations/liger/args.py +++ b/src/axolotl/integrations/liger/args.py @@ -84,6 +84,18 @@ def check_liger_rms_norm_tensor_parallel(cls, data): ) return data + @model_validator(mode="before") + @classmethod + def check_liger_use_token_scaling_flce(cls, data): + if data.get("liger_use_token_scaling") and not data.get( + "liger_fused_linear_cross_entropy" + ): + raise ValueError( + "`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled." + ) + + return data + @model_validator(mode="after") def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self): # TODO @SalmanMohammadi this is a larger fix - investigate diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py index 9ff12312a0..ac796c2c90 100644 --- a/src/axolotl/integrations/liger/plugin.py +++ b/src/axolotl/integrations/liger/plugin.py @@ -49,11 +49,6 @@ def pre_model_load(self, cfg): ) if cfg.liger_use_token_scaling: - if not cfg.liger_fused_linear_cross_entropy: - raise ValueError( - "`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled." - ) - # Patch FLCE to set token_scaling=True for function and class API from liger_kernel.transformers import functional from liger_kernel.transformers.fused_linear_cross_entropy import (