
Commit 720cc68

yeshsurya and Claude authored
[feat]: Add support for gpt-oss (#949)
## Summary

This work adds support for GPT-OSS models.

## Testing Done

Inference benchmark execution:

- Total inference runs: 18 executions (12 successful benchmark measurements, 6 pre-run warmup iterations)
- Configurations tested: 3 scenarios with 2 runs each
- All scenarios passed with consistent, measurable improvements
- Hardware Type: RTX A6000

- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Claude <[email protected]>
1 parent 3d075fd commit 720cc68

File tree

8 files changed: +434 additions, -1 deletion


README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -264,6 +264,7 @@ loss.backward()
 | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Olmo3 | `liger_kernel.transformers.apply_liger_kernel_to_olmo3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
+| GPT-OSS | `liger_kernel.transformers.apply_liger_kernel_to_gpt_oss` | RoPE, RMSNorm, CrossEntropyLoss, FusedLinearCrossEntropy |
 | InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | HunyuanV1 MoE | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
```
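For reference, a minimal sketch of the flow this README row documents: patch first, then load, mirroring the `loss.backward()` context in the hunk header. The checkpoint name follows the docstring example added later in this commit; treat the prompt text as illustrative.

```python
from transformers import AutoTokenizer, GptOssForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

# Patch before from_pretrained so the class-level RoPE/RMSNorm/FLCE
# replacements take effect for the instantiated model.
apply_liger_kernel_to_gpt_oss()

model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

inputs = tokenizer("Liger Kernel works with GPT-OSS.", return_tensors="pt")
loss = model(**inputs, labels=inputs.input_ids).loss
loss.backward()
```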

src/liger_kernel/ops/rms_norm.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -351,7 +351,7 @@ def _block_rms_norm_backward_kernel(
 
     # calculate the gradient of W
     if casting_mode == _CASTING_MODE_LLAMA:
-        # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
+        # TODO(tcc): use tl.sum(..., dtype=tl.float32) once we upgrade to triton>=3.3.0
         dW_row += tl.sum((dY_row * (X_row * rstd_row[:, None]).to(X_dtype)).to(tl.float32), 0)
     else:
         # here X_row is already in fp32 (see previous if block)
```
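As an aside on why this kernel upcasts before reducing (the `.to(tl.float32)` inside `tl.sum`): an fp16 accumulator stalls once the running sum grows large enough that the local spacing of representable values exceeds twice the addend. A plain-PyTorch sketch with illustrative numbers (not from the commit):

```python
import torch

# 4096 * 0.01 should be ~40.97, but an fp16 accumulator stops moving
# near 32.0, where the gap between representable values (2**-5)
# exceeds twice the addend, so each addition rounds to a no-op.
x = torch.full((4096,), 0.01, dtype=torch.float16)

acc16 = torch.tensor(0.0, dtype=torch.float16)
for v in x:
    acc16 = acc16 + v              # pure fp16 accumulation

acc32 = x.to(torch.float32).sum()  # upcast first, as the kernel does

print(acc16.item(), acc32.item())  # ~32.0 vs ~40.97
```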

src/liger_kernel/transformers/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -41,6 +41,7 @@
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe  # noqa: F401
@@ -110,6 +111,7 @@ def __getattr__(name: str):
     "apply_liger_kernel_to_glm4",
     "apply_liger_kernel_to_glm4v",
     "apply_liger_kernel_to_glm4v_moe",
+    "apply_liger_kernel_to_gpt_oss",
     "apply_liger_kernel_to_granite",
     "apply_liger_kernel_to_internvl",
     "apply_liger_kernel_to_llama",
@@ -187,6 +189,7 @@ def __getattr__(name: str):
     "apply_liger_kernel_to_glm4",
     "apply_liger_kernel_to_glm4v",
     "apply_liger_kernel_to_glm4v_moe",
+    "apply_liger_kernel_to_gpt_oss",
     "apply_liger_kernel_to_granite",
     "apply_liger_kernel_to_internvl",
     "apply_liger_kernel_to_llama",
```
src/liger_kernel/transformers/model/gpt_oss.py

Lines changed: 211 additions & 0 deletions (new file)

````python
from typing import List
from typing import Optional
from typing import Union

import torch

from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast


def lce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> LigerMoeCausalLMOutputWithPast:
    r"""
    Forward pass for causal language modeling with Mixture of Experts (MoE) architecture using Liger Kernel optimizations.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using tokenizers.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings.
        past_key_values (`List[torch.FloatTensor]` or `Cache`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up
            sequential decoding. See `past_key_values` input for more details.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_router_logits (`bool`, *optional*):
            Whether or not to return the router logits of all MoE layers. See `router_logits` under returned tensors
            for more detail.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence.
        logits_to_keep (`int` or `torch.Tensor`, *optional*, defaults to 0):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).
        skip_logits (`bool`, *optional*):
            Whether to skip logit computation and directly compute loss. If `None`, defaults to `True` during training
            when labels are provided (to save memory), and `False` during inference.

    Returns:
        `LigerMoeCausalLMOutputWithPast`: An output object containing:
            - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
                Language modeling loss (for next-token prediction), including the auxiliary load balancing loss.
            - aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
                Auxiliary load balancing loss for the sparse MoE modules.
            - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
                Note: logits are `None` during training when `skip_logits=True` to save memory.
            - past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed):
                Cached key and value projection states for faster sequential decoding.
            - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
                Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for each layer) of shape
                `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer.
            - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
                Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
                sequence_length)`. Attentions weights after the attention softmax.
            - router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True`):
                Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
                Router logits of the MoE layers, useful to compute the auxiliary loss and z_loss.
            - token_accuracy (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
                Token-level prediction accuracy.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, GptOssForCausalLM
    >>> from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss

    >>> # Apply Liger Kernel patches for optimized performance
    >>> apply_liger_kernel_to_gpt_oss()

    >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
    >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Inference: Forward pass returns logits
    >>> outputs = model(**inputs)
    >>> outputs.logits.shape
    torch.Size([1, 12, 201088])

    >>> # Get next token prediction
    >>> next_token_logits = outputs.logits[:, -1, :]
    >>> predicted_token_id = next_token_logits.argmax(dim=-1)

    >>> # Training: Forward pass with labels returns loss
    >>> labels = inputs.input_ids.clone()
    >>> outputs = model(**inputs, labels=labels)
    >>> outputs.loss
    tensor(2.6454)
    ```"""

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_router_logits = (
        output_router_logits if output_router_logits is not None else self.config.output_router_logits
    )

    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: MoeModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None
    token_accuracy = None

    if skip_logits is None:
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        result = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
        loss, _, token_accuracy = unpack_cross_entropy_result(result)
    else:  # if in inference mode, materialize logits
        logits = self.lm_head(kept_hidden_states)
        if labels is not None or shift_labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                shift_labels=shift_labels,
                vocab_size=self.vocab_size,
                **kwargs,
            )

    aux_loss = None
    if output_router_logits:
        aux_loss = load_balancing_loss_func(
            outputs.router_logits,
            self.num_experts,
            self.num_experts_per_tok,
            attention_mask,
        )
        if labels is not None:
            loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

    return LigerMoeCausalLMOutputWithPast(
        loss=loss,
        aux_loss=aux_loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        router_logits=outputs.router_logits,
        token_accuracy=token_accuracy,
    )
````
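Reusing `model` and `inputs` from the docstring example above, a minimal sketch of the training-side path: with labels set and the model in training mode, `skip_logits` resolves to `True`, the fused linear cross entropy branch runs, and `outputs.logits` stays `None`. The assertion reflects the documented behavior and is illustrative, not an API contract.

```python
model.train()
outputs = model(
    **inputs,
    labels=inputs.input_ids,
    output_router_logits=True,  # folds router_aux_loss_coef * aux_loss into loss
)
assert outputs.logits is None   # fused path never materializes logits
outputs.loss.backward()         # loss already includes the load-balancing term
print(outputs.aux_loss, outputs.token_accuracy)
```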

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 75 additions & 0 deletions
```diff
@@ -20,6 +20,7 @@
 from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
 from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
 from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
+from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
 from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
@@ -1459,6 +1460,79 @@ def apply_liger_kernel_to_qwen3_moe(
             _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
 
 
+def apply_liger_kernel_to_gpt_oss(
+    rope: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = False,  # Set to False by default since GPT-OSS has custom expert implementation
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
+    NOTE: GPT-OSS is supported in transformers >= 4.55.0
+    NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
+    implementation with clamping and MXFP4 quantization.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is
+            more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
+            Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    if version.parse(transformers.__version__) < version.parse("4.55.0"):
+        logger.warning("GPT-OSS support requires transformers >= 4.55.0")
+        return
+
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.gpt_oss import modeling_gpt_oss
+    from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel
+
+    if rope:
+        modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb
+
+    if rms_norm:
+        modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm
+
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(gpt_oss_lce_forward, model)
+        else:
+            modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward
+
+    # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
+    # with clamping (swiglu_limit=7.0) and MXFP4 quantization
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+
+        # get the base model from the model instance
+        base_model: GptOssModel = getattr(model, model.base_model_prefix, model)
+
+        if rms_norm:
+            _patch_rms_norm_module(base_model.norm)
+        for decoder_layer in base_model.layers:
+            if rms_norm:
+                _patch_rms_norm_module(decoder_layer.input_layernorm)
+                _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+
+
 def apply_liger_kernel_to_qwen2_vl(
     rope: bool = True,
     cross_entropy: bool = False,
@@ -2752,6 +2826,7 @@ def apply_liger_kernel_to_hunyuan_v1_moe(
     "glm4": apply_liger_kernel_to_glm4,
     "glm4v": apply_liger_kernel_to_glm4v,
     "glm4v_moe": apply_liger_kernel_to_glm4v_moe,
+    "gpt_oss": apply_liger_kernel_to_gpt_oss,
     "internvl": apply_liger_kernel_to_internvl,
     "llama": apply_liger_kernel_to_llama,
     "llama4_text": apply_liger_kernel_to_llama4,
```
