
Commit 8b4c62f

Tcc0403 and lancerts authored
Add AutoLigerKernelForCausalLM.from_config (#962)
## Summary

Fix #943

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

Signed-off-by: Tcc0403 <[email protected]>
Co-authored-by: Shao Tang <[email protected]>
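As a usage illustration (not part of the commit), here is a minimal sketch of the new entry point, assuming `liger-kernel` and `transformers` are installed. The `LlamaConfig` values below are arbitrary placeholders:

```python
# Sketch: build a freshly initialized (untrained) model with Liger kernels
# applied, via the new from_config classmethod. Config values are illustrative.
from transformers import LlamaConfig

from liger_kernel.transformers import AutoLigerKernelForCausalLM

config = LlamaConfig(
    hidden_size=512,
    intermediate_size=1024,
    num_hidden_layers=2,
    num_attention_heads=4,
)

# Kernel flags (e.g. rope, swiglu) are consumed by apply_liger_kernel_to_llama;
# all remaining kwargs are forwarded to AutoModelForCausalLM.from_config.
model = AutoLigerKernelForCausalLM.from_config(config, rope=False, swiglu=True)
```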
1 parent 6059dfb commit 8b4c62f

File tree

2 files changed: +56 -0 lines

src/liger_kernel/transformers/auto_model.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -1,11 +1,14 @@
 import inspect
+import logging
 
 from transformers import AutoConfig
 from transformers import AutoModelForCausalLM
 
 from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
 from liger_kernel.transformers.monkey_patch import _apply_liger_kernel
 
+logger = logging.getLogger(__name__)
+
 
 def _get_model_config(model_dir, **model_init_kwargs):
     config = AutoConfig.from_pretrained(model_dir, **model_init_kwargs)
@@ -36,3 +39,21 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
 
         return super().from_pretrained(pretrained_model_name_or_path, *model_args, **applicable_kwargs)
+
+    @classmethod
+    def from_config(cls, config, **kwargs):
+        model_type = getattr(config, "model_type", None)
+        if not model_type:
+            logger.info("Model type could not be determined from model config. No Liger kernels will be applied.")
+            return
+        model_type = config.model_type
+
+        _apply_liger_kernel(model_type, **kwargs)
+
+        # Filter out kwargs that were passed to the apply_liger_* function, which will cause
+        # model initialization errors otherwise
+        apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[model_type]
+        apply_fn_signature = inspect.signature(apply_fn)
+        applicable_kwargs = {key: value for key, value in kwargs.items() if key not in apply_fn_signature.parameters}
+
+        return super().from_config(config, **applicable_kwargs)
```
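The kwarg splitting above mirrors the existing `from_pretrained` path: arguments matching the `apply_liger_*` function's signature are consumed by the kernel patch, and everything else is forwarded to the parent `from_config`. A self-contained sketch of that pattern, with hypothetical stand-in names (`apply_patch`, `build_model`) rather than Liger's real functions:

```python
# Self-contained sketch of the signature-based kwarg filtering used above.
# apply_patch and build_model are hypothetical stand-ins, not Liger APIs.
import inspect


def apply_patch(rope=True, swiglu=True):
    print(f"patching kernels: rope={rope}, swiglu={swiglu}")


def build_model(**model_kwargs):
    print(f"building model with {model_kwargs}")


def from_config_like(**kwargs):
    params = inspect.signature(apply_patch).parameters
    # Kernel flags go to the patch function...
    apply_patch(**{k: v for k, v in kwargs.items() if k in params})
    # ...and the remainder is forwarded to the model constructor.
    build_model(**{k: v for k, v in kwargs.items() if k not in params})


from_config_like(rope=False, torch_dtype="bfloat16")
# patching kernels: rope=False, swiglu=True
# building model with {'torch_dtype': 'bfloat16'}
```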

test/transformers/test_auto_model.py

Lines changed: 35 additions & 0 deletions
```diff
@@ -52,3 +52,38 @@ def test_auto_liger_kernel_for_causal_lm_from_pretrained():
             pretrained_model_name_or_path, *model_args, **original_kwargs
         )
         assert model == "mock_model"
+
+
+def test_auto_liger_kernel_for_causal_lm_from_config():
+    original_kwargs = {
+        "valid_arg_1": "some_value_1",
+        "valid_arg_2": 10,
+    }
+
+    # These args should be passed through to apply_liger_kernel_to_llama fn
+    apply_liger_kernel_kwargs = {
+        "rope": False,
+        "swiglu": True,
+    }
+
+    kwargs = {**original_kwargs, **apply_liger_kernel_kwargs}
+
+    # Mock the model config instance returned from AutoConfig.from_pretrained()
+    mock_model_config = MagicMock()
+    mock_model_config.model_type = "llama"
+    mock_llama = mock.Mock()
+
+    with (
+        patch.dict(MODEL_TYPE_TO_APPLY_LIGER_FN, {"llama": mock_llama}),
+        mock.patch.object(AutoModelForCausalLM, "from_config", return_value="mock_model") as mock_super_from_config,
+    ):
+        # Mock the function signature of apply_liger_kernel_to_llama
+        mock_llama.__signature__ = signature(apply_liger_kernel_to_llama)
+
+        model = AutoLigerKernelForCausalLM.from_config(mock_model_config, **kwargs)
+
+        # Check that the apply_liger_kernel_to_llama mock was called with the correct kwargs
+        mock_llama.assert_called_once_with(rope=False, swiglu=True)
+        # Check that the original kwargs are passed to super().from_config
+        mock_super_from_config.assert_called_once_with(mock_model_config, **original_kwargs)
+        assert model == "mock_model"
```
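One detail worth noting in the test: `inspect.signature()` honors a `__signature__` attribute when one is set, so assigning the real patch function's signature to the mock lets the kwarg filter in `from_config` see the genuine parameter list. A minimal demonstration (the stub below is illustrative, not the full `apply_liger_kernel_to_llama` signature):

```python
# Why the test sets mock_llama.__signature__: inspect.signature() returns an
# object's __signature__ attribute when present, so a Mock can report the
# real patch function's parameters. The stub signature here is illustrative.
import inspect
from unittest import mock


def apply_liger_kernel_to_llama(rope=True, swiglu=True):  # illustrative stub
    pass


mock_fn = mock.Mock()
mock_fn.__signature__ = inspect.signature(apply_liger_kernel_to_llama)

# The mock now advertises the stub's parameters to inspect.signature().
print(inspect.signature(mock_fn))  # -> (rope=True, swiglu=True)
```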

0 commit comments
