
Commit 90cd00b

vvvdwbvvv and lancerts authored

Add glm4.1v model support (#858)

## Summary

This PR adds support for GLM4.1V (GLM-4 Vision) models to the Liger Kernel (#854).

https://huggingface.co/zai-org/GLM-4.1V-9B-Thinking

This model has been merged in huggingface/transformers#38431.

## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>
1 parent cee2b56 commit 90cd00b
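
For orientation, the new entry point follows the same pattern as the existing `apply_liger_kernel_to_*` functions. A minimal usage sketch, assuming the patch is applied before the model is instantiated (the checkpoint id is taken from the docstring example in the diff below):

```python
import torch
from transformers import Glm4vForConditionalGeneration

from liger_kernel.transformers import apply_liger_kernel_to_glm4v

# Patch the GLM-4.1V modeling classes in transformers first, so the Liger
# RMSNorm / SwiGLU / fused linear cross entropy replacements are picked up
# when the model is constructed.
apply_liger_kernel_to_glm4v()

model = Glm4vForConditionalGeneration.from_pretrained(
    "THUDM/GLM-4.1V-9B-Thinking",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```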

File tree

9 files changed: +732 −5 lines changed

src/liger_kernel/transformers/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -35,6 +35,7 @@
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4  # noqa: F401
@@ -93,6 +94,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_gemma3",
         "apply_liger_kernel_to_gemma3_text",
         "apply_liger_kernel_to_glm4",
+        "apply_liger_kernel_to_glm4v",
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_llama",
         "apply_liger_kernel_to_llava",
@@ -156,6 +158,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_gemma3",
         "apply_liger_kernel_to_gemma3_text",
         "apply_liger_kernel_to_glm4",
+        "apply_liger_kernel_to_glm4v",
         "apply_liger_kernel_to_granite",
         "apply_liger_kernel_to_llama",
         "apply_liger_kernel_to_llava",
```
src/liger_kernel/transformers/model/glm4v.py

Lines changed: 150 additions & 0 deletions (new file)

````python
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils.deprecation import deprecate_kwarg

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    skip_logits: Optional[bool] = None,
    **kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them
            only for that token can save memory, which becomes pretty significant for long sequences or large
            vocabulary size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence
            length dimension. This is useful when using packed tensor format (single dimension for batch and
            sequence length).

    Returns:

    Example:

    ```python
    >>> import torch
    >>> from transformers import AutoProcessor, Glm4vForConditionalGeneration

    >>> MODEL_PATH = "THUDM/GLM-4.1V-9B-Thinking"
    >>> messages = [
    ...     {
    ...         "role": "user",
    ...         "content": [
    ...             {
    ...                 "type": "image",
    ...                 "url": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
    ...             },
    ...             {"type": "text", "text": "describe this image"},
    ...         ],
    ...     }
    ... ]
    >>> processor = AutoProcessor.from_pretrained(MODEL_PATH, use_fast=True)
    >>> model = Glm4vForConditionalGeneration.from_pretrained(
    ...     pretrained_model_name_or_path=MODEL_PATH,
    ...     torch_dtype=torch.bfloat16,
    ...     device_map="auto",
    ... )
    >>> inputs = processor.apply_chat_template(
    ...     messages,
    ...     tokenize=True,
    ...     add_generation_prompt=True,
    ...     return_dict=True,
    ...     return_tensors="pt",
    ... ).to(model.device)
    >>> generated_ids = model.generate(**inputs, max_new_tokens=8192)
    >>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
    <think>Got it, let's describe the image. First, there's a vintage car, specifically a Volkswagen Beetle
    ```"""
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs[0]
    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    kept_hidden_states = hidden_states[:, slice_indices, :]

    shift_labels = kwargs.pop("shift_labels", None)
    logits = None
    loss = None

    if skip_logits and labels is None and shift_labels is None:
        raise ValueError("skip_logits is True, but labels and shift_labels are None")

    if skip_logits is None:
        # By default, if in training mode, don't materialize logits
        skip_logits = self.training and (labels is not None or shift_labels is not None)

    if skip_logits:
        loss = LigerForCausalLMLoss(
            hidden_states=kept_hidden_states,
            lm_head_weight=self.lm_head.weight,
            labels=labels,
            shift_labels=shift_labels,
            hidden_size=self.config.hidden_size,
            **kwargs,
        )
    else:
        logits = self.lm_head(kept_hidden_states)
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
````
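
The interesting branch is `skip_logits`: in training mode with labels, the `lm_head` projection is never materialized and `LigerForCausalLMLoss` fuses it into the loss computation. A sketch of how one might exercise both paths; it assumes `model` is a patched `Glm4vForConditionalGeneration` and `inputs` is a processor batch as in the docstring example above:

```python
import torch

# Training path: labels provided and model.training is True, so skip_logits
# defaults to True; the loss comes from LigerForCausalLMLoss directly from
# hidden states and lm_head.weight, and no logits tensor is materialized.
model.train()
out = model(**inputs, labels=inputs["input_ids"])
assert out.loss is not None and out.logits is None

# Inference path: no labels, so skip_logits defaults to False; logits are
# materialized through lm_head, here only for the final position.
model.eval()
with torch.no_grad():
    out = model(**inputs, logits_to_keep=1)
assert out.logits is not None and out.loss is None
```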

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 90 additions & 0 deletions
```diff
@@ -1839,13 +1839,103 @@ def apply_liger_kernel_to_glm4(
             _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
 
 
+def apply_liger_kernel_to_glm4v(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace GLM-4v models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+            loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+    )
+
+    from transformers.models.glm4v import modeling_glm4v
+    from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
+    from transformers.models.glm4v.modeling_glm4v import Glm4vModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vTextModel
+    from transformers.models.glm4v.modeling_glm4v import Glm4vVisionModel
+
+    from liger_kernel.transformers.model.glm4v import lce_forward as glm4v_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4
+
+    if rope:
+        raise NotImplementedError("liger_rotary_pos_emb is not available for Glm4v models.")
+    if rms_norm:
+        modeling_glm4v.Glm4vRMSNorm = LigerRMSNormForGlm4
+    if cross_entropy:
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
+    if fused_linear_cross_entropy:
+        if model is not None:
+            model.forward = MethodType(glm4v_lce_forward, model)
+        else:
+            modeling_glm4v.Glm4vForConditionalGeneration.forward = glm4v_lce_forward
+
+    if model is not None:
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (Glm4vForConditionalGeneration, Glm4vModel)):
+            # Note: the language_model and visual properties can be accessed through the conditional class for BC.
+            # Not sure if it is subject to changes in the future.
+            # Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modeling_glm4v.py#L1305
+            text_model: Glm4vTextModel = model.language_model
+            vision_model: Glm4vVisionModel = model.visual
+        elif isinstance(model, Glm4vTextModel):
+            text_model: Glm4vTextModel = model
+            vision_model = None
+        else:
+            # Note: Currently there's no support for patching the vision model only. Feel free to raise an issue if needed.
+            raise TypeError(
+                f"Unsupported glm4.1v model type. `model` must be `Glm4vForConditionalGeneration`, `Glm4vModel` or `Glm4vTextModel`. Got: {type(model)}"
+            )
+
+        if vision_model is not None:
+            for vision_block in vision_model.blocks:
+                if rms_norm:
+                    _patch_rms_norm_module(vision_block.norm1)
+                    _patch_rms_norm_module(vision_block.norm2)
+                if swiglu:
+                    _patch_swiglu_module(vision_block.mlp, LigerSwiGLUMLP)
+
+        if text_model is not None:
+            if rms_norm:
+                _patch_rms_norm_module(text_model.norm)
+            for decoder_layer in text_model.layers:
+                if swiglu:
+                    _patch_swiglu_module(decoder_layer.mlp, LigerPhi3SwiGLUMLP)
+                if rms_norm:
+                    _patch_rms_norm_module(decoder_layer.input_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm)
+                    _patch_rms_norm_module(decoder_layer.post_mlp_layernorm)
+
+
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
     "gemma2": apply_liger_kernel_to_gemma2,
     "gemma3_text": apply_liger_kernel_to_gemma3_text,
     "gemma3": apply_liger_kernel_to_gemma3,
     "glm4": apply_liger_kernel_to_glm4,
+    "glm4v": apply_liger_kernel_to_glm4v,
     "llama": apply_liger_kernel_to_llama,
     "llama4_text": apply_liger_kernel_to_llama4,
     "llama4": apply_liger_kernel_to_llama4,
```
