Skip to content

Commit 8e231a3

Browse files
Manan17lancertsManan Shahvaibhavjindal
authored
Liger support for Llama4 (#740)
## Summary <!--- This is a required section; please describe the main purpose of this proposed code change. ---> Adding liger support to llama4 (text and multimodal). Liger rope is not supported yet for llama4 as it has a different implementation. The tolerance are set higher than usual to pass the tests. <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> <!-- Replace BLANK with your device type. For example, A100-80G-PCIe Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> - Hardware Type: NVIDIA H100 80GB HBM3 - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Co-authored-by: Shao Tang <[email protected]> Co-authored-by: Manan Shah <[email protected]> Co-authored-by: Vaibhav Jindal <[email protected]>
1 parent d47751c commit 8e231a3

File tree

13 files changed

+905
-27
lines changed

13 files changed

+905
-27
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,7 @@ loss.backward()
241241

242242
| **Model** | **API** | **Supported Operations** |
243243
|-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
244+
| Llama4 (Text) & (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_llama4` | RMSNorm, LayerNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
244245
| LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
245246
| LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
246247
| Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

src/liger_kernel/transformers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
3131
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
3232
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
33+
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
3334
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llava # noqa: F401
3435
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mistral # noqa: F401
3536
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral # noqa: F401
@@ -87,6 +88,7 @@ def __getattr__(name: str):
8788
"apply_liger_kernel_to_granite",
8889
"apply_liger_kernel_to_llama",
8990
"apply_liger_kernel_to_llava",
91+
"apply_liger_kernel_to_llama4",
9092
"apply_liger_kernel_to_mistral",
9193
"apply_liger_kernel_to_mixtral",
9294
"apply_liger_kernel_to_mllama",
@@ -141,6 +143,7 @@ def __getattr__(name: str):
141143
"apply_liger_kernel_to_granite",
142144
"apply_liger_kernel_to_llama",
143145
"apply_liger_kernel_to_llava",
146+
"apply_liger_kernel_to_llama4",
144147
"apply_liger_kernel_to_mistral",
145148
"apply_liger_kernel_to_mixtral",
146149
"apply_liger_kernel_to_mllama",
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from typing import List
2+
from typing import Optional
3+
from typing import Tuple
4+
from typing import Union
5+
6+
import torch
7+
8+
from transformers.cache_utils import Cache
9+
from transformers.modeling_outputs import CausalLMOutputWithPast
10+
11+
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
12+
13+
14+
def lce_forward(
15+
self,
16+
input_ids: torch.LongTensor = None,
17+
attention_mask: Optional[torch.Tensor] = None,
18+
position_ids: Optional[torch.LongTensor] = None,
19+
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
20+
inputs_embeds: Optional[torch.FloatTensor] = None,
21+
labels: Optional[torch.LongTensor] = None,
22+
use_cache: Optional[bool] = None,
23+
output_attentions: Optional[bool] = None,
24+
output_hidden_states: Optional[bool] = None,
25+
return_dict: Optional[bool] = None,
26+
cache_position: Optional[torch.LongTensor] = None,
27+
logits_to_keep: Union[int, torch.Tensor] = 0,
28+
**kwargs,
29+
) -> Union[Tuple, CausalLMOutputWithPast]:
30+
r"""
31+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
32+
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
33+
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
34+
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
35+
36+
Example:
37+
38+
```python
39+
>>> from transformers import AutoTokenizer, Llama4ForCausalLM
40+
41+
>>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf")
42+
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf")
43+
44+
>>> prompt = "Hey, are you conscious? Can you talk to me?"
45+
>>> inputs = tokenizer(prompt, return_tensors="pt")
46+
47+
>>> # Generate
48+
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
49+
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
50+
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
51+
```"""
52+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
53+
output_hidden_states = (
54+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
55+
)
56+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
57+
58+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
59+
outputs = self.model(
60+
input_ids=input_ids,
61+
attention_mask=attention_mask,
62+
position_ids=position_ids,
63+
past_key_values=past_key_values,
64+
inputs_embeds=inputs_embeds,
65+
use_cache=use_cache,
66+
output_attentions=output_attentions,
67+
output_hidden_states=output_hidden_states,
68+
return_dict=True,
69+
cache_position=cache_position,
70+
**kwargs,
71+
)
72+
73+
hidden_states = outputs[0]
74+
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
75+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
76+
kept_hidden_states = hidden_states[:, slice_indices, :]
77+
78+
shift_labels = kwargs.pop("shift_labels", None)
79+
logits = None
80+
loss = None
81+
82+
if self.training and (labels is not None or shift_labels is not None):
83+
loss = LigerForCausalLMLoss(
84+
hidden_states=kept_hidden_states,
85+
lm_head_weight=self.lm_head.weight,
86+
labels=labels,
87+
shift_labels=shift_labels,
88+
hidden_size=self.config.hidden_size,
89+
**kwargs,
90+
)
91+
92+
else: # if in inference mode materialize logits
93+
logits = self.lm_head(kept_hidden_states)
94+
if labels is not None:
95+
loss = self.loss_function(
96+
logits=logits,
97+
labels=labels,
98+
vocab_size=self.config.vocab_size,
99+
**kwargs,
100+
)
101+
102+
return CausalLMOutputWithPast(
103+
loss=loss,
104+
logits=logits,
105+
past_key_values=outputs.past_key_values,
106+
hidden_states=outputs.hidden_states,
107+
attentions=outputs.attentions,
108+
)

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,92 @@ def apply_liger_kernel_to_llava(
363363
logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
364364

365365

366+
def apply_liger_kernel_to_llama4(
367+
rope: bool = False,
368+
cross_entropy: bool = False,
369+
fused_linear_cross_entropy: bool = True,
370+
rms_norm: bool = True,
371+
swiglu: bool = True,
372+
model: PreTrainedModel = None,
373+
layer_norm: bool = True,
374+
) -> None:
375+
"""
376+
Apply Liger kernels to replace original implementation in HuggingFace Llama4 models.
377+
378+
Args:
379+
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
380+
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
381+
fused_linear_cross_entropy (bool):
382+
Whether to apply Liger's fused linear cross entropy loss. Default is True.
383+
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
384+
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
385+
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
386+
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
387+
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
388+
loaded. Default is None.
389+
"""
390+
assert not (cross_entropy and fused_linear_cross_entropy), (
391+
"cross_entropy and fused_linear_cross_entropy cannot both be True."
392+
)
393+
394+
from transformers.models.llama4 import modeling_llama4
395+
from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM
396+
from transformers.models.llama4.modeling_llama4 import Llama4ForConditionalGeneration
397+
from transformers.models.llama4.modeling_llama4 import Llama4TextModel
398+
from transformers.models.llama4.modeling_llama4 import Llama4VisionModel
399+
400+
from liger_kernel.transformers.model.llama4 import lce_forward as llama4_lce_forward
401+
402+
if rope:
403+
raise NotImplementedError("liger_rotary_pos_emb is not available for Llama4 models.")
404+
if rms_norm:
405+
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
406+
if swiglu:
407+
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
408+
409+
if cross_entropy:
410+
modeling_llama4.CrossEntropyLoss = LigerCrossEntropyLoss
411+
412+
if fused_linear_cross_entropy:
413+
modeling_llama4.Llama4ForCausalLM.forward = llama4_lce_forward
414+
415+
if model is not None:
416+
# The model instance already exists, so we need to additionally patch the
417+
# instance variables that reference already-instantiated modules
418+
if isinstance(model, Llama4ForConditionalGeneration):
419+
language_model: Llama4ForCausalLM = model.language_model
420+
vision_model: Llama4VisionModel = model.vision_model
421+
text_model: Llama4TextModel = language_model.model
422+
elif isinstance(model, Llama4ForCausalLM):
423+
text_model = model.model
424+
vision_model = None
425+
elif isinstance(model, Llama4TextModel):
426+
text_model = model
427+
vision_model = None
428+
429+
else:
430+
raise ValueError(f"Unsupported Llama4 model type: {type(model)}")
431+
432+
if text_model:
433+
if rms_norm:
434+
_patch_rms_norm_module(text_model.norm)
435+
for decoder_layer in text_model.layers:
436+
if swiglu:
437+
_patch_swiglu_module(decoder_layer.feed_forward, LigerSwiGLUMLP)
438+
if rms_norm:
439+
_patch_rms_norm_module(decoder_layer.input_layernorm)
440+
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
441+
442+
if vision_model:
443+
_patch_layer_norm_module(vision_model.layernorm_pre)
444+
_patch_layer_norm_module(vision_model.layernorm_post)
445+
446+
for layer in vision_model.model.layers:
447+
if layer_norm:
448+
_patch_layer_norm_module(layer.input_layernorm)
449+
_patch_layer_norm_module(layer.post_attention_layernorm)
450+
451+
366452
def apply_liger_kernel_to_mllama(
367453
rope: bool = True,
368454
cross_entropy: bool = False,
@@ -1605,6 +1691,8 @@ def apply_liger_kernel_to_glm4(
16051691
"gemma3": apply_liger_kernel_to_gemma3,
16061692
"glm4": apply_liger_kernel_to_glm4,
16071693
"llama": apply_liger_kernel_to_llama,
1694+
"llama4_text": apply_liger_kernel_to_llama4,
1695+
"llama4": apply_liger_kernel_to_llama4,
16081696
"llava": apply_liger_kernel_to_llava,
16091697
"granite": apply_liger_kernel_to_granite,
16101698
"mllama": apply_liger_kernel_to_mllama,

test/convergence/bf16/test_mini_models.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from transformers.models.gemma2 import Gemma2ForCausalLM
1010
from transformers.models.llama import LlamaConfig
1111
from transformers.models.llama import LlamaForCausalLM
12+
from transformers.models.llama4 import Llama4ForCausalLM
13+
from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
1214
from transformers.models.mistral import MistralConfig
1315
from transformers.models.mistral import MistralForCausalLM
1416
from transformers.models.mixtral import MixtralConfig
@@ -24,6 +26,7 @@
2426
from liger_kernel.transformers import apply_liger_kernel_to_glm4
2527
from liger_kernel.transformers import apply_liger_kernel_to_granite
2628
from liger_kernel.transformers import apply_liger_kernel_to_llama
29+
from liger_kernel.transformers import apply_liger_kernel_to_llama4
2730
from liger_kernel.transformers import apply_liger_kernel_to_llava
2831
from liger_kernel.transformers import apply_liger_kernel_to_mistral
2932
from liger_kernel.transformers import apply_liger_kernel_to_mixtral
@@ -46,6 +49,7 @@
4649
from test.utils import revert_liger_kernel_to_glm4
4750
from test.utils import revert_liger_kernel_to_granite
4851
from test.utils import revert_liger_kernel_to_llama
52+
from test.utils import revert_liger_kernel_to_llama4
4953
from test.utils import revert_liger_kernel_to_llava
5054
from test.utils import revert_liger_kernel_to_mistral
5155
from test.utils import revert_liger_kernel_to_mixtral
@@ -152,6 +156,35 @@
152156
device = infer_device()
153157

154158
MINI_MODEL_SETUPS = {
159+
"mini_llama4": MiniModelConfig(
160+
liger_kernel_patch_func=apply_liger_kernel_to_llama4,
161+
liger_kernel_patch_revert_func=revert_liger_kernel_to_llama4,
162+
model_class=Llama4ForCausalLM,
163+
mini_model_config=Llama4TextConfig(
164+
bos_token_id=1, # None
165+
eos_token_id=2, # 151329, 151336, 151338
166+
pad_token_id=2, # 151329
167+
partial_rotary_factor=1.0,
168+
cross_attention_layers=None,
169+
dropout=0,
170+
hidden_act="silu",
171+
hidden_size=1024, # 6144
172+
initializer_range=0.02,
173+
intermediate_size=2048, # 14336
174+
max_position_embeddings=4096, # 32768
175+
num_attention_heads=8, # 48
176+
num_hidden_layers=4, # 61
177+
num_key_value_heads=2,
178+
rms_norm_eps=1e-5,
179+
rope_scaling=None,
180+
rope_theta=10000.0,
181+
tie_word_embeddings=False,
182+
use_cache=True,
183+
vocab_size=32000, # 151552
184+
attention_bias=True,
185+
attn_implementation="sdpa", # default value, pytorch native attention
186+
),
187+
),
155188
"mini_llama3": MiniModelConfig(
156189
liger_kernel_patch_func=apply_liger_kernel_to_llama,
157190
liger_kernel_patch_revert_func=revert_liger_kernel_to_llama,
@@ -380,6 +413,7 @@
380413
),
381414
}
382415

416+
383417
if QWEN3_AVAILABLE:
384418
MINI_MODEL_SETUPS["mini_qwen3"] = MiniModelConfig(
385419
liger_kernel_patch_func=apply_liger_kernel_to_qwen3,
@@ -770,7 +804,7 @@
770804
)
771805

772806

773-
def create_model(model_name="mini_llama3"):
807+
def create_model(model_name="mini_llama4"):
774808
"""
775809
Create a mini version model
776810
The commented values are the original values
@@ -781,7 +815,7 @@ def create_model(model_name="mini_llama3"):
781815

782816

783817
def run_mini_model(
784-
model_name="mini_llama3",
818+
model_name="mini_llama4",
785819
num_steps=100,
786820
dtype=torch.bfloat16,
787821
lr=1e-5,
@@ -804,7 +838,7 @@ def run_mini_model(
804838
"rms_norm": True,
805839
}
806840

807-
if "glm4" in model_name:
841+
if "glm4" in model_name or "llama4" in model_name:
808842
kwargs["rope"] = False
809843

810844
model_supports_layer_norm = "qwen2_vl" in model_name
@@ -865,6 +899,19 @@ def run_mini_model(
865899
@pytest.mark.parametrize(
866900
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logprobs_atol, logprobs_rtol, param_atol, param_rtol",
867901
[
902+
pytest.param(
903+
"mini_llama4",
904+
32,
905+
1e-4,
906+
torch.bfloat16,
907+
1e-3,
908+
1e-2,
909+
1e-1,
910+
1e-1,
911+
1e-2,
912+
1e-2,
913+
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
914+
),
868915
pytest.param(
869916
"mini_llama3",
870917
32,

0 commit comments

Comments
 (0)