diff --git a/src/axolotl/integrations/liger/README.md b/src/axolotl/integrations/liger/README.md
index c5cce8282b..3a2d4bd04b 100644
--- a/src/axolotl/integrations/liger/README.md
+++ b/src/axolotl/integrations/liger/README.md
@@ -18,6 +18,9 @@ liger_rms_norm: true
 liger_glu_activation: true
 liger_layer_norm: true
 liger_fused_linear_cross_entropy: true
+
+# FLCE-specific
+liger_use_token_scaling: true
 ```
 
 ## Supported Models
diff --git a/src/axolotl/integrations/liger/args.py b/src/axolotl/integrations/liger/args.py
index d5bb10cfdd..eb7a6c59be 100644
--- a/src/axolotl/integrations/liger/args.py
+++ b/src/axolotl/integrations/liger/args.py
@@ -16,7 +16,7 @@
 Module for handling LIGER input arguments.
 """
 
-from pydantic import BaseModel, model_validator
+from pydantic import BaseModel, Field, model_validator
 
 from axolotl.utils.logging import get_logger
 
@@ -35,6 +35,15 @@ class LigerArgs(BaseModel):
     liger_glu_activation: bool | None = None
     liger_cross_entropy: bool | None = None
     liger_fused_linear_cross_entropy: bool | None = None
+    liger_use_token_scaling: bool | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": (
+                "Enables use_token_scaling in fused_linear_cross_entropy. "
+                "When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
+            )
+        },
+    )
 
     @model_validator(mode="before")
     @classmethod
@@ -75,6 +84,18 @@ def check_liger_rms_norm_tensor_parallel(cls, data):
             )
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_liger_use_token_scaling_flce(cls, data):
+        if data.get("liger_use_token_scaling") and not data.get(
+            "liger_fused_linear_cross_entropy"
+        ):
+            raise ValueError(
+                "`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled."
+            )
+
+        return data
+
     @model_validator(mode="after")
     def check_tensor_parallel_size_liger_fused_linear_cross_entropy(self):
         # TODO @SalmanMohammadi this is a larger fix - investigate
diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py
index 89f7c37b71..ac796c2c90 100644
--- a/src/axolotl/integrations/liger/plugin.py
+++ b/src/axolotl/integrations/liger/plugin.py
@@ -48,6 +48,33 @@ def pre_model_load(self, cfg):
                 "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
             )
 
+        if cfg.liger_use_token_scaling:
+            # Patch FLCE to set use_token_scaling=True for both the functional and class APIs
+            from liger_kernel.transformers import functional
+            from liger_kernel.transformers.fused_linear_cross_entropy import (
+                LigerFusedLinearCrossEntropyLoss,
+            )
+
+            old_liger_fused_linear_cross_entropy = (
+                functional.liger_fused_linear_cross_entropy
+            )
+
+            def patched_liger_fused_linear_cross_entropy(*args, **kwargs):
+                kwargs["use_token_scaling"] = True
+                return old_liger_fused_linear_cross_entropy(*args, **kwargs)
+
+            functional.liger_fused_linear_cross_entropy = (
+                patched_liger_fused_linear_cross_entropy
+            )
+
+            old_init = LigerFusedLinearCrossEntropyLoss.__init__
+
+            def patched_init(self, *args, **kwargs):
+                kwargs["use_token_scaling"] = True
+                return old_init(self, *args, **kwargs)
+
+            LigerFusedLinearCrossEntropyLoss.__init__ = patched_init
+
         if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
             apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
             liger_fn_sig = inspect.signature(apply_liger_fn)
diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py
index 2859699633..55317151e5 100644
--- a/tests/e2e/integrations/test_liger.py
+++ b/tests/e2e/integrations/test_liger.py
@@ -2,6 +2,7 @@
 Simple end-to-end test for Liger integration
 """
 
+import pytest
 from axolotl.common.datasets import load_datasets
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
@@ -62,7 +63,11 @@ def test_llama_wo_flce(self, temp_dir):
         check_model_output_exists(temp_dir, cfg)
 
     @require_torch_2_4_1
-    def test_llama_w_flce(self, temp_dir):
+    @pytest.mark.parametrize(
+        "liger_use_token_scaling",
+        [True, False],
+    )
+    def test_llama_w_flce(self, temp_dir, liger_use_token_scaling):
         cfg = DictDefault(
             {
                 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -74,6 +79,7 @@ def test_llama_w_flce(self, temp_dir):
                 "liger_glu_activation": True,
                 "liger_cross_entropy": False,
                 "liger_fused_linear_cross_entropy": True,
+                "liger_use_token_scaling": liger_use_token_scaling,
                 "sequence_len": 1024,
                 "val_set_size": 0.05,
                 "special_tokens": {
diff --git a/tests/integrations/test_liger.py b/tests/integrations/test_liger.py
index d7b171ec27..6865306c9d 100644
--- a/tests/integrations/test_liger.py
+++ b/tests/integrations/test_liger.py
@@ -75,3 +75,19 @@ def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg):
         ):
             prepare_plugins(test_cfg)
             validate_config(test_cfg)
+
+    def test_use_token_scaling_require_flce(self, minimal_liger_cfg):
+        test_cfg = DictDefault(
+            {
+                "liger_fused_linear_cross_entropy": False,
+                "liger_use_token_scaling": True,
+            }
+            | minimal_liger_cfg
+        )
+
+        with pytest.raises(
+            ValueError,
+            match=r"`liger_use_token_scaling: true` requires `liger_fused_linear_cross_entropy` enabled.",
+        ):
+            prepare_plugins(test_cfg)
+            validate_config(test_cfg)
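The `Field` description added in `args.py` states the scaling rule in words: each token's loss is multiplied by its predicted probability, detached from gradients. Below is a minimal eager-mode PyTorch sketch of that rule, assuming "predicted probability" means the softmax probability assigned to the target token. It illustrates the semantics only; the real computation happens inside Liger's fused linear cross-entropy kernel, and the function name here is hypothetical.

```python
# Sketch of the use_token_scaling rule described in the Field docstring.
# Assumption: the scale is the softmax probability of the target token,
# detached so the scaling factor itself contributes no gradient.
import torch
import torch.nn.functional as F


def token_scaled_cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # logits: (num_tokens, vocab_size); targets: (num_tokens,)
    per_token_loss = F.cross_entropy(logits, targets, reduction="none")
    # Probability the model assigned to the correct token, detached from the graph.
    target_prob = (
        torch.softmax(logits, dim=-1)
        .gather(-1, targets.unsqueeze(-1))
        .squeeze(-1)
        .detach()
    )
    return (per_token_loss * target_prob).mean()
```

A mean reduction is shown only for simplicity; reduction and ignore-index handling in training follow whatever the fused Liger kernel does.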