
Commit 30da53d

[Model] Liger support for SmolLM3 (#798)
## Summary

This PR adds Liger support for the SmolLM3 model.

## Details

The model architecture is similar to that of Llama, so I was able to reuse most of its ops.

## Testing Done

- Hardware Type: A100-80G-PCIe
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Steven Shimizu <[email protected]>
1 parent fc8fd33 commit 30da53d
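Note (not part of the commit): a minimal usage sketch of the new entry point. It assumes a CUDA device (Liger's Triton kernels require one) and a `transformers` release that includes SmolLM3; the model id is the one used in the docstring further down this page, and the sample text and training call are illustrative.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from liger_kernel.transformers import apply_liger_kernel_to_smollm3

# Patch the HF SmolLM3 modeling code before instantiating the model so that RoPE,
# RMSNorm, SwiGLU, and the fused linear cross entropy loss come from Liger kernels.
apply_liger_kernel_to_smollm3()

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM3-3B", torch_dtype=torch.bfloat16
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")

batch = tokenizer("Liger kernels reduce peak memory during training.", return_tensors="pt").to("cuda")
model.train()  # in training mode the patched forward computes the loss without materializing logits
loss = model(**batch, labels=batch["input_ids"]).loss
loss.backward()
```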

File tree

9 files changed: +581 -0 lines changed


src/liger_kernel/transformers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,7 @@
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3  # noqa: F401
     from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen3_moe  # noqa: F401
+    from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_smollm3  # noqa: F401


 # Check if 'transformers' is installed
@@ -100,6 +101,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_qwen2_vl",
         "apply_liger_kernel_to_qwen3",
         "apply_liger_kernel_to_qwen3_moe",
+        "apply_liger_kernel_to_smollm3",
     }

     if name in monkey_patch_symbols:
@@ -155,5 +157,6 @@ def __getattr__(name: str):
             "apply_liger_kernel_to_qwen2_vl",
             "apply_liger_kernel_to_qwen3",
             "apply_liger_kernel_to_qwen3_moe",
+            "apply_liger_kernel_to_smollm3",
         ]
     )
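Side note (illustrative, not from the diff): because this `__init__` resolves monkey-patch symbols lazily through `__getattr__`, the new function is importable from the package root in the usual way once `transformers` is available in the environment.

```python
# Resolved through the lazy __getattr__ path shown above; assumes `transformers`
# is installed (see the availability check in this file).
from liger_kernel.transformers import apply_liger_kernel_to_smollm3
```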
src/liger_kernel/transformers/model/smollm3.py

Lines changed: 189 additions & 0 deletions (new file)
@@ -0,0 +1,189 @@
from typing import TYPE_CHECKING
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from torch.distributed.fsdp import FullyShardedDataParallel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils.deprecation import deprecate_kwarg

from liger_kernel.transformers.fsdp import _FSDPForwardRedirection
from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.utils import PEFT_AVAILABLE

if TYPE_CHECKING:
    from transformers.cache_utils import Cache

if PEFT_AVAILABLE:
    from peft.utils.other import ModulesToSaveWrapper


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union["Cache", List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Smollm3ForCausalLM

    >>> model = Smollm3ForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")
    >>> tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM3-3B")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    # if in training mode, don't materialize logits
    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        loss = lce_maybe_trainable_lm_head(
            self,
            hidden_states=kept_hidden_states,
            hidden_size=self.config.hidden_size,
            labels=labels,
            shift_labels=shift_labels,
            **kwargs,
        )

    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def lce_maybe_trainable_lm_head(self, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    lm_head = self.lm_head

    # Unwrap the module if lm_head has been added as trainable module in PEFT LoRA configuration,
    # i.e. listed in the modules_to_save field of LoraConfig, so the lm_head weights are read
    # from the unwrapped module.
    # See https://huggingface.co/docs/peft/package_reference/lora for reference.
    if PEFT_AVAILABLE and isinstance(lm_head, ModulesToSaveWrapper):
        lm_head = lm_head.modules_to_save.default

    # If FSDP is used and lm_head is trainable, e.g., during full fine-tuning or with LoRA,
    # reading the lm_head module weights and calling the kernel must be done within FSDP forward pass
    # so the module entire parameters are summoned and kept in memory during the kernel execution.
    if isinstance(lm_head, FullyShardedDataParallel):
        return _FSDPForwardRedirection()(
            lm_head,
            _liger_for_causal_lm_loss,
            lm_head.module,
            hidden_states,
            hidden_size,
            labels,
            shift_labels,
            **loss_kwargs,
        )

    # FSDP is not used so we can read the lm_head weights and call the kernel directly
    return _liger_for_causal_lm_loss(
        lm_head=self.lm_head,
        hidden_states=hidden_states,
        hidden_size=hidden_size,
        labels=labels,
        shift_labels=shift_labels,
        **loss_kwargs,
    )


def _liger_for_causal_lm_loss(lm_head, hidden_states, hidden_size, labels, shift_labels, **loss_kwargs):
    return LigerForCausalLMLoss(
        hidden_states=hidden_states,
        lm_head_weight=lm_head.weight,
        labels=labels,
        hidden_size=hidden_size,
        shift_labels=shift_labels,
        **loss_kwargs,
    )
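A rough sketch of how the `skip_logits` default behaves once the patch is installed (illustrative; the random `input_ids`, device handling, and assertions are not from the commit):

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_smollm3

apply_liger_kernel_to_smollm3()  # installs the lce_forward above on SmolLM3ForCausalLM
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B").cuda()
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")

# Training mode with labels: skip_logits defaults to True, so LigerForCausalLMLoss
# consumes the hidden states and lm_head.weight directly and no full
# (batch, seq, vocab) logits tensor is materialized.
model.train()
out = model(input_ids=input_ids, labels=input_ids)
assert out.loss is not None and out.logits is None

# Eval mode without labels: the regular lm_head projection runs and full logits
# are returned, matching stock Hugging Face behavior.
model.eval()
with torch.no_grad():
    out = model(input_ids=input_ids)
assert out.logits is not None and out.loss is None
```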

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 73 additions & 0 deletions
@@ -29,6 +29,7 @@
 from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
+from liger_kernel.transformers.model.smollm3 import lce_forward as smollm3_lce_forward
 from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -290,6 +291,77 @@ def apply_liger_kernel_to_llama(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)


+def apply_liger_kernel_to_smollm3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace SmolLM3 model
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.smollm3 import modeling_smollm3
+    from transformers.models.smollm3.modeling_smollm3 import SmolLM3Model
+
+    if rope:
+        modeling_smollm3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_smollm3.SmolLM3RMSNorm = LigerRMSNorm
+    if swiglu:
+        modeling_smollm3.SmolLM3MLP = LigerSwiGLUMLP
+
+    if cross_entropy:
+        if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
+            from transformers.loss.loss_utils import nn
+
+            nn.functional.cross_entropy = liger_cross_entropy
+        else:
+            logger.warning(TRANSFORMER_DEPRECATION_WARNING)
+            modeling_smollm3.CrossEntropyLoss = LigerCrossEntropyLoss
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(smollm3_lce_forward, model)
+        else:
+            modeling_smollm3.SmolLM3ForCausalLM.forward = smollm3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules (e.g. SmolLM3RMSNorm or SmolLM3MLP)
+
+        # get the base model from the model instance
+        base_model: SmolLM3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_llava(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
@@ -1801,6 +1873,7 @@ def apply_liger_kernel_to_glm4(
     "qwen2_vl_text": apply_liger_kernel_to_qwen2_vl,
     "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
     "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl,
+    "smollm3": apply_liger_kernel_to_smollm3,
     "phi3": apply_liger_kernel_to_phi3,
     "paligemma": apply_liger_kernel_to_paligemma,
 }
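For completeness, a sketch of the two ways the new `apply_liger_kernel_to_smollm3` entry point can be invoked, mirroring the `model` handling above (illustrative, not part of the commit):

```python
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_smollm3

# Option 1: patch at the module level before the model is created, so new SmolLM3
# instances pick up Liger RoPE/RMSNorm/SwiGLU and the fused loss automatically.
apply_liger_kernel_to_smollm3()

# Option 2: patch an already-instantiated model. In addition to swapping the
# module-level symbols, the instance's existing RMSNorm and MLP submodules are
# patched in place and lce_forward is bound directly to model.forward.
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM3-3B")
apply_liger_kernel_to_smollm3(model=model)

# cross_entropy and fused_linear_cross_entropy are mutually exclusive; to patch in
# the plain Liger cross entropy instead of the fused loss, disable the latter.
apply_liger_kernel_to_smollm3(cross_entropy=True, fused_linear_cross_entropy=False)
```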

0 commit comments
