
Commit 8a93398

Olmo3 model support [ready for review] (#946)
## Summary

Add support for Olmo 3 models: https://huggingface.co/allenai/Olmo-3-1125-32B

## Details

Olmo 3 is similar to Olmo 2 in all of the ways that matter to Liger-Kernel. It is a separate model type in `transformers` (in order to support SWA), so instead of just aliasing Olmo 3 to Olmo 2 I add a new model type here.

## Testing Done

Hardware Type: NVIDIA RTX 4090

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

Co-authored-by: Vaibhav Jindal <[email protected]>
1 parent 98d9e19 commit 8a93398

File tree

11 files changed: +518, -6 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -262,6 +262,7 @@ loss.backward()
 | Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
 | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| Olmo3 | `liger_kernel.transformers.apply_liger_kernel_to_olmo3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
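
For context, a minimal sketch of how the new table entry is meant to be used, following the same pattern as the other `apply_liger_kernel_to_*` APIs (the 32B model ID comes from the commit summary; any Olmo 3 checkpoint should work the same way):

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_olmo3

# Patch the HuggingFace Olmo3 modeling code before the model is created so that
# RoPE, RMSNorm, SwiGLU, and the fused linear cross entropy kernels are swapped in.
apply_liger_kernel_to_olmo3()

model = AutoModelForCausalLM.from_pretrained(
    "allenai/Olmo-3-1125-32B",
    torch_dtype=torch.bfloat16,
)
```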

src/liger_kernel/transformers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,7 @@
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mixtral  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_mllama  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo3  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2  # noqa: F401
@@ -118,6 +119,7 @@ def __getattr__(name: str):
     "apply_liger_kernel_to_mixtral",
     "apply_liger_kernel_to_mllama",
     "apply_liger_kernel_to_olmo2",
+    "apply_liger_kernel_to_olmo3",
     "apply_liger_kernel_to_paligemma",
     "apply_liger_kernel_to_phi3",
     "apply_liger_kernel_to_qwen2",
@@ -194,6 +196,7 @@ def __getattr__(name: str):
     "apply_liger_kernel_to_mixtral",
     "apply_liger_kernel_to_mllama",
     "apply_liger_kernel_to_olmo2",
+    "apply_liger_kernel_to_olmo3",
     "apply_liger_kernel_to_paligemma",
     "apply_liger_kernel_to_phi3",
     "apply_liger_kernel_to_qwen2",
src/liger_kernel/transformers/model/olmo3.py (new file)

Lines changed: 142 additions & 0 deletions

from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils.deprecation import deprecate_kwarg

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Olmo3ForCausalLM

    >>> model = Olmo3ForCausalLM.from_pretrained("allenai/Olmo-3-7B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("allenai/Olmo-3-7B-Instruct")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
    ```
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    # Compute loss
    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)

    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    if not return_dict:
        output = (logits,) + outputs[1:]
        output = ((loss,) + output) if loss is not None else output
        output = output + (token_accuracy,) if token_accuracy is not None else output
        return output

    # Return custom output class with token_accuracy field
    return LigerCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        token_accuracy=token_accuracy,
    )
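
As a side note, a toy illustration of the `logits_to_keep` slicing used in `lce_forward` above (shapes are arbitrary and unrelated to any real checkpoint):

```python
import torch

hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden)

# logits_to_keep == 0 keeps the whole sequence: slice(-0, None) == slice(0, None)
assert hidden_states[:, slice(-0, None), :].shape == (2, 8, 16)

# logits_to_keep == 1 keeps only the last position, which is all generation needs
assert hidden_states[:, slice(-1, None), :].shape == (2, 1, 16)

# A 1D index tensor keeps arbitrary positions (packed-sequence format)
keep = torch.tensor([0, 3, 7])
assert hidden_states[:, keep, :].shape == (2, 3, 16)
```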

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 70 additions & 1 deletion
@@ -1928,6 +1928,74 @@ def apply_liger_kernel_to_olmo2(
                 _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)


+def apply_liger_kernel_to_olmo3(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace Olmo3 models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU to Olmo3MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.olmo3 import modeling_olmo3
+    from transformers.models.olmo3.modeling_olmo3 import Olmo3Model
+
+    from liger_kernel.transformers.model.olmo3 import lce_forward as olmo3_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForOlmo2
+
+    # Olmo3 arch is very similar to Olmo2, so we can reuse all these components in the same way.
+    if rope:
+        modeling_olmo3.apply_rotary_pos_emb = liger_rotary_pos_emb
+    if rms_norm:
+        modeling_olmo3.Olmo3RMSNorm = LigerRMSNormForOlmo2  # same as olmo2
+    if swiglu:
+        modeling_olmo3.Olmo3MLP = LigerSwiGLUMLP
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(olmo3_lce_forward, model)
+        else:
+            modeling_olmo3.Olmo3ForCausalLM.forward = olmo3_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: Olmo3Model = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+
+        for decoder_layer in base_model.layers:
+            if swiglu:
+                _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm, in_place=False)
+                _patch_rms_norm_module(decoder_layer.post_feedforward_layernorm, in_place=False)
+
+
 def apply_liger_kernel_to_glm4(
     rope: bool = False,
     cross_entropy: bool = False,
@@ -2589,7 +2657,7 @@ def apply_liger_kernel_to_hunyuan_v1_dense(
         from transformers.loss.loss_utils import nn

         nn.functional.cross_entropy = liger_cross_entropy
-
+
     if fused_linear_cross_entropy:
         if model is not None:
             model.forward = MethodType(hunyuan_v1_lce_forward, model)
@@ -2695,6 +2763,7 @@ def apply_liger_kernel_to_hunyuan_v1_moe(
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
     "olmo2": apply_liger_kernel_to_olmo2,
+    "olmo3": apply_liger_kernel_to_olmo3,
     "qwen2": apply_liger_kernel_to_qwen2,
     "qwen3": apply_liger_kernel_to_qwen3,
     "qwen3_moe": apply_liger_kernel_to_qwen3_moe,

src/liger_kernel/transformers/swiglu.py

Lines changed: 1 addition & 1 deletion
@@ -93,4 +93,4 @@ def __init__(self, config, layer_idx=None, is_shared_mlp=False):
             raise ValueError(f"Activation function {config.hidden_act} not supported.")

     def forward(self, x):
-        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))
+        return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x)))

(The removed and added lines are textually identical, so this appears to be a whitespace-only change such as the trailing newline at end of file.)

test/convergence/bf16/test_mini_models.py

Lines changed: 60 additions & 1 deletion
@@ -40,6 +40,7 @@
 from liger_kernel.transformers import apply_liger_kernel_to_mixtral
 from liger_kernel.transformers import apply_liger_kernel_to_mllama
 from liger_kernel.transformers import apply_liger_kernel_to_olmo2
+from liger_kernel.transformers import apply_liger_kernel_to_olmo3
 from liger_kernel.transformers import apply_liger_kernel_to_phi3
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl
@@ -74,6 +75,7 @@
 from test.utils import revert_liger_kernel_to_mixtral
 from test.utils import revert_liger_kernel_to_mllama
 from test.utils import revert_liger_kernel_to_olmo2
+from test.utils import revert_liger_kernel_to_olmo3
 from test.utils import revert_liger_kernel_to_phi3
 from test.utils import revert_liger_kernel_to_qwen2
 from test.utils import revert_liger_kernel_to_qwen2_5_vl
@@ -194,6 +196,15 @@
 except ImportError:
     OLMO2_AVAILABLE = False

+try:
+    # OLMO3 is only available in transformers>=4.57.0
+    from transformers.models.olmo3.configuration_olmo3 import Olmo3Config
+    from transformers.models.olmo3.modeling_olmo3 import Olmo3ForCausalLM
+
+    OLMO3_AVAILABLE = True
+except ImportError:
+    OLMO3_AVAILABLE = False
+
 try:
     # Glm4 is only available in transformers>=4.51.3
     from transformers.models.glm4.configuration_glm4 import Glm4Config
@@ -1009,6 +1020,35 @@
         ),
     )

+if OLMO3_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_olmo3"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_olmo3,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_olmo3,
+        model_class=Olmo3ForCausalLM,
+        mini_model_config=Olmo3Config(
+            bos_token_id=1,  # 128000
+            eos_token_id=2,  # 128001
+            pad_token_id=2,
+            cross_attention_layers=None,
+            dropout=0,
+            hidden_act="silu",
+            hidden_size=1024,  # 4096
+            initializer_range=0.02,
+            intermediate_size=2048,  # 14336
+            max_position_embeddings=4096,
+            num_attention_heads=8,  # 32
+            num_hidden_layers=4,  # 40
+            num_key_value_heads=2,  # 8
+            rms_norm_eps=1e-5,
+            rope_scaling=None,
+            rope_theta=500_000,
+            tie_word_embeddings=False,
+            use_cache=True,
+            vocab_size=32000,  # 128256,
+            attn_implementation="sdpa",  # default value, pytorch native attention
+        ),
+    )
+
 if GLM4_AVAILABLE:
     MINI_MODEL_SETUPS["mini_glm4"] = MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_glm4,
@@ -1351,7 +1391,7 @@
         liger_kernel_patch_func=apply_liger_kernel_to_hunyuan_v1_moe,
         liger_kernel_patch_revert_func=revert_liger_kernel_to_hunyuan_v1_moe,
         model_class=HunYuanMoEV1ForCausalLM,
-        mini_model_config = HunYuanMoEV1Config(
+        mini_model_config=HunYuanMoEV1Config(
             vocab_size=32000,
             hidden_size=128,
             intermediate_size=512,
@@ -1751,6 +1791,25 @@ def run_mini_model(
                 ),
             ],
         ),
+        pytest.param(
+            "mini_olmo3",
+            32,
+            1e-5,
+            torch.bfloat16,
+            1e-2,
+            1e-2,
+            1e-1,
+            1e-2,
+            1e-2,
+            1e-2,
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not OLMO3_AVAILABLE,
+                    reason="OLMO3 not available in this version of transformers",
+                ),
+            ],
+        ),
         pytest.param(
             "mini_glm4",
             32,
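
For reference, a quick way one might sanity-check the mini config used in the convergence test outside of pytest (this assumes transformers>=4.57.0, where `Olmo3Config` and `Olmo3ForCausalLM` are available, as gated by the try/except above):

```python
from transformers.models.olmo3.configuration_olmo3 import Olmo3Config
from transformers.models.olmo3.modeling_olmo3 import Olmo3ForCausalLM

# Same scaled-down shape as MINI_MODEL_SETUPS["mini_olmo3"]
config = Olmo3Config(
    hidden_size=1024,
    intermediate_size=2048,
    num_attention_heads=8,
    num_hidden_layers=4,
    num_key_value_heads=2,
    vocab_size=32000,
    rope_theta=500_000,
    tie_word_embeddings=False,
)
model = Olmo3ForCausalLM(config)
print(f"mini_olmo3 parameters: {sum(p.numel() for p in model.parameters()):,}")
```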
