
Commit 1449c19

Manan17, lancerts, and vaibhavjindal authored
fix: Fix llava (#743)
## Summary

Resolve a part of #723

## Testing Done

Convergence tests:

bf16/test_mini_models:

```
python -m pytest test/convergence/bf16/test_mini_models.py -k llava
===================================================================================== test session starts =====================================================================================
platform linux -- Python 3.10.14, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/jobuser/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.7.0, rerunfailures-15.1, anyio-4.9.0, lipy-config-base-32.9.0, lipy-fabric-36.1.5, lipy-test-9.1.34, datadir-1.6.1
collecting ...
------------------------------------------------------------------------------------- live log collection -------------------------------------------------------------------------------------
INFO datasets:config.py:54 PyTorch version 2.7.0 available.
collected 16 items / 15 deselected / 1 selected

test/convergence/bf16/test_mini_models.py::test_mini_model[mini_llava-32-0.0001-dtype1-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [100%]
```

bf16/test_mini_models_multimodal:

```
python -m pytest test/convergence/bf16/test_mini_models_multimodal.py -k llava
===================================================================================== test session starts =====================================================================================
platform linux -- Python 3.10.14, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/jobuser/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.7.0, rerunfailures-15.1, anyio-4.9.0, lipy-config-base-32.9.0, lipy-fabric-36.1.5, lipy-test-9.1.34, datadir-1.6.1
collecting ...
------------------------------------------------------------------------------------- live log collection -------------------------------------------------------------------------------------
INFO datasets:config.py:54 PyTorch version 2.7.0 available.
collected 7 items / 6 deselected / 1 selected

test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_llava-32-0.0001-dtype1-0.001-0.01-0.1-0.01-0.01-0.01] PASSED
```

fp32/test_mini_models:

```
python -m pytest test/convergence/fp32/test_mini_models.py -k llava
===================================================================================== test session starts =====================================================================================
platform linux -- Python 3.10.14, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/jobuser/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.7.0, rerunfailures-15.1, anyio-4.9.0, lipy-config-base-32.9.0, lipy-fabric-36.1.5, lipy-test-9.1.34, datadir-1.6.1
collecting ...
------------------------------------------------------------------------------------- live log collection -------------------------------------------------------------------------------------
INFO datasets:config.py:54 PyTorch version 2.7.0 available.
collected 17 items / 16 deselected / 1 selected

test/convergence/fp32/test_mini_models.py::test_mini_model[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [100%]
```

fp32/test_mini_models_multimodal:

```
python -m pytest test/convergence/fp32/test_mini_models_multimodal.py -k llava
===================================================================================== test session starts =====================================================================================
platform linux -- Python 3.10.14, pytest-8.3.5, pluggy-1.5.0
rootdir: /home/jobuser/Liger-Kernel
configfile: pyproject.toml
plugins: xdist-3.7.0, rerunfailures-15.1, anyio-4.9.0, lipy-config-base-32.9.0, lipy-fabric-36.1.5, lipy-test-9.1.34, datadir-1.6.1
collecting ...
------------------------------------------------------------------------------------- live log collection -------------------------------------------------------------------------------------
INFO datasets:config.py:54 PyTorch version 2.7.0 available.
collected 7 items / 6 deselected / 1 selected

test/convergence/fp32/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [100%]
```

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>
Co-authored-by: Vaibhav Jindal <[email protected]>
1 parent ca04cad, commit 1449c19

2 files changed: +87, -156 lines

2 files changed

+87
-156
lines changed

src/liger_kernel/transformers/model/llava.py
Lines changed: 83 additions & 153 deletions

```diff
@@ -7,10 +7,9 @@
 
 from torch.nn import CrossEntropyLoss
 from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
-from transformers.utils import is_torchdynamo_compiling
-from transformers.utils.deprecation import deprecate_kwarg
 
 from liger_kernel.transformers.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyLoss
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
 def lce_forward_deprecated(
@@ -28,6 +27,11 @@ def lce_forward_deprecated(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    image_sizes: torch.Tensor = None,
+    skip_logits: Optional[bool] = None,
+    **lm_kwargs,
 ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
     r"""
     Args:
@@ -36,10 +40,12 @@ def lce_forward_deprecated(
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
-        num_logits_to_keep (`int`, *optional*):
-            Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
             `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
             token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
 
 
     Returns:
@@ -65,7 +71,6 @@ def lce_forward_deprecated(
         >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "USER:  \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
         ```"""
-
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -89,73 +94,24 @@ def lce_forward_deprecated(
     )
 
     if inputs_embeds is None:
-        # 1. Extra the input embeddings
         inputs_embeds = self.get_input_embeddings()(input_ids)
 
-        # 2. Merge text and images
-        if pixel_values is not None and input_ids.shape[1] != 1:
-            image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
-            # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
-            selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
-
-            if vision_feature_select_strategy == "default":
-                selected_image_feature = selected_image_feature[:, 1:]
-            elif vision_feature_select_strategy == "full":
-                selected_image_feature = selected_image_feature
-            else:
-                raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
-
-            image_features = self.multi_modal_projector(selected_image_feature)
-            inputs_embeds = inputs_embeds.to(image_features.dtype)
-            inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
-                image_features, inputs_embeds, input_ids, attention_mask, labels
-            )
-
-        # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
-        # generation with cache
-        elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
-            # Retrieve the first layer to inspect the logits and mask out the hidden states
-            # that are set to 0
-            first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
-
-            # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
-            batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
-
-            # Get the target length
-            target_length = input_ids.shape[1]
-            past_length = first_layer_past_key_value.shape[-1]
-
-            extended_attention_mask = torch.ones(
-                (attention_mask.shape[0], past_length),
-                dtype=attention_mask.dtype,
-                device=attention_mask.device,
-            )
-
-            # Filter out only the tokens that can be un-attended, this can happen
-            # if one uses Llava + Fused modules where the cache on the
-            # first iteration is already big enough, or if one passes custom cache
-            valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
-            new_batch_index = batch_index[valid_indices]
-            new_non_attended_tokens = non_attended_tokens[valid_indices]
-
-            # Zero-out the places where we don't need to attend
-            extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
-
-            attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
-            position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
-
-        # TODO: @raushan retain only the new behavior after v4.47
-        elif image_features is not None:
-            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
-            n_image_features = image_features.shape[0] * image_features.shape[1]
+    if pixel_values is not None:
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=vision_feature_layer,
+            vision_feature_select_strategy=vision_feature_select_strategy,
+            image_sizes=image_sizes,
+        )
 
-            if n_image_tokens != n_image_features:
+        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
+            n_image_tokens = (input_ids == self.config.image_token_index).sum()
+            n_image_features = image_features.shape[0] * image_features.shape[1]
             raise ValueError(
                 f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
-        special_image_mask = (
-            (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-        )
         image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
         inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
 
@@ -168,13 +124,19 @@ def lce_forward_deprecated(
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
         return_dict=return_dict,
+        cache_position=cache_position,
+        logits_to_keep=logits_to_keep,
+        **lm_kwargs,
     )
     hidden_states = outputs[0]
 
     loss = None
     logits = None
 
-    if self.training and (labels is not None):
+    # Overwrite skip_logits, since llava never materializes logits
+    skip_logits = labels is not None
+
+    if skip_logits:
         # Shift so that tokens < n predict n
         if attention_mask is not None:
             # we use the input attention mask to shift the logits and labels, because it is 2D.
@@ -189,21 +151,34 @@ def lce_forward_deprecated(
             shift_labels = labels[..., 1:].contiguous()
 
         lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+        loss = lce(
+            self.language_model.lm_head.weight,
+            shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
+            shift_labels.view(-1).to(shift_hidden_states.device),
+        )
     else:
         logits = self.language_model.lm_head(hidden_states)
         if labels is not None:
-            # Shift so that tokens < n predict n
+            # Upcast to float if we need to compute the loss to avoid potential precision issues
+            logits = logits.float()
+            shift_logits = logits[..., :-1, :]
+            shift_labels = labels[..., 1:]
             if attention_mask is not None:
-                shift_attention_mask = attention_mask[..., 1:]
-                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
-                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
+                # we use the input attention mask to shift the logits and labels, because it is 2D.
+                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
+                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
+                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
+                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
             else:
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
+                shift_logits = shift_logits.contiguous()
+                shift_labels = shift_labels.contiguous()
             # Flatten the tokens
             loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device))
+
+            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
+            flat_labels = shift_labels.view(-1).to(shift_logits.device)
+            loss = loss_fct(flat_logits, flat_labels)
+
    if not return_dict:
        # NOTE: This part has not been tested.
        output = outputs[1:]
```
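
Both loss branches in the hunk above rely on the same shift-by-one pattern: the activation at position `i` is trained to predict the label at position `i + 1`. A tiny self-contained sketch of just that pattern, in plain PyTorch with illustrative shapes only:

```python
import torch

# Position i of the shifted activations is paired with label i + 1,
# so each token is trained to predict the next token.
labels = torch.tensor([[10, 11, 12, 13]])  # (batch=1, seq=4)
hidden_states = torch.randn(1, 4, 8)       # (batch, seq, hidden)

shift_hidden_states = hidden_states[..., :-1, :]  # positions 0..2
shift_labels = labels[..., 1:]                    # targets 11, 12, 13

print(shift_hidden_states.shape)  # torch.Size([1, 3, 8])
print(shift_labels)               # tensor([[11, 12, 13]])
```
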
```diff
@@ -215,10 +190,9 @@ def lce_forward_deprecated(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
+        image_hidden_states=image_features if pixel_values is not None else None,
     )
 
-
-@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -292,103 +266,59 @@ def lce_forward(
         else self.config.vision_feature_select_strategy
     )
 
-    if (input_ids is None) ^ (inputs_embeds is not None):
-        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-    if pixel_values is not None and inputs_embeds is not None:
-        raise ValueError(
-            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
-        )
-
-    if inputs_embeds is None:
-        inputs_embeds = self.get_input_embeddings()(input_ids)
-
-    if pixel_values is not None:
-        image_features = self.get_image_features(
-            pixel_values=pixel_values,
-            vision_feature_layer=vision_feature_layer,
-            vision_feature_select_strategy=vision_feature_select_strategy,
-            image_sizes=image_sizes,
-        )
-
-        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
-        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-        if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
-            n_image_tokens = (input_ids == self.config.image_token_index).sum()
-            n_image_features = image_features.shape[0] * image_features.shape[1]
-            raise ValueError(
-                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-            )
-        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-    outputs = self.language_model.model(
+    outputs = self.model(
+        input_ids=input_ids,
+        pixel_values=pixel_values,
         attention_mask=attention_mask,
         position_ids=position_ids,
         past_key_values=past_key_values,
         inputs_embeds=inputs_embeds,
+        vision_feature_layer=vision_feature_layer,
+        vision_feature_select_strategy=vision_feature_select_strategy,
         use_cache=use_cache,
         output_attentions=output_attentions,
         output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
+        return_dict=True,
         cache_position=cache_position,
-        logits_to_keep=logits_to_keep,
+        image_sizes=image_sizes,
         **lm_kwargs,
     )
     hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
 
-    loss = None
+    shift_labels = lm_kwargs.pop("shift_labels", None)
     logits = None
+    loss = None
 
-    # Overwrite skip_logits, since llava never materializes logits
-    skip_logits = labels is not None
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
 
-    if skip_logits:
-        # Shift so that tokens < n predict n
-        if attention_mask is not None:
-            # we use the input attention mask to shift the logits and labels, because it is 2D.
-            # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
-            shift_attention_mask = attention_mask[:, -(hidden_states.shape[1] - 1) :].to(hidden_states.device)
-            shift_hidden_states = hidden_states[..., :-1, :][
-                shift_attention_mask.to(hidden_states.device) != 0
-            ].contiguous()
-            shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
-        else:
-            shift_hidden_states = hidden_states[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
 
-        lce = LigerFusedLinearCrossEntropyLoss()
-        loss = lce(
-            self.language_model.lm_head.weight,
-            shift_hidden_states.view(-1, shift_hidden_states.size(-1)),
-            shift_labels.view(-1).to(shift_hidden_states.device),
+    if skip_logits:
+        loss = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.text_config.hidden_size,
+            **lm_kwargs,
         )
+
     else:
-        logits = self.language_model.lm_head(hidden_states)
+        logits = self.lm_head(kept_hidden_states)
         if labels is not None:
-            # Upcast to float if we need to compute the loss to avoid potential precision issues
-            logits = logits.float()
-            shift_logits = logits[..., :-1, :]
-            shift_labels = labels[..., 1:]
-            if attention_mask is not None:
-                # we use the input attention mask to shift the logits and labels, because it is 2D.
-                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
-                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
-                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
-                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
-            else:
-                shift_logits = shift_logits.contiguous()
-                shift_labels = shift_labels.contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
+            loss = self.loss_function(
+                logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **lm_kwargs
+            )
 
-            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
-            flat_labels = shift_labels.view(-1).to(shift_logits.device)
-            loss = loss_fct(flat_logits, flat_labels)
 
     if not return_dict:
-        # NOTE: This part has not been tested.
-        output = outputs[1:]
+        output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
 
     return LlavaCausalLMOutputWithPast(
@@ -397,5 +327,5 @@ def lce_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        image_hidden_states=image_features if pixel_values is not None else None,
+        image_hidden_states=outputs.image_hidden_states,
     )
```
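
The heart of the new `lce_forward` is visible above: hidden states are first narrowed with `logits_to_keep`, and when `skip_logits` resolves to true the loss is computed directly from the hidden states and the `lm_head` weight via `LigerForCausalLMLoss`, so the full logits tensor is never materialized. A minimal sketch of that flow in plain PyTorch; `naive_fused_linear_ce` below is an illustrative stand-in for the fused Liger kernel, not the real implementation:

```python
import torch
import torch.nn.functional as F


def naive_fused_linear_ce(lm_head_weight, hidden, labels):
    # Stand-in: the real Liger kernel fuses this projection with the cross
    # entropy so the (tokens, vocab) logits tensor is never materialized.
    logits = hidden @ lm_head_weight.t()
    return F.cross_entropy(logits, labels, ignore_index=-100)


batch, seq, hidden_size, vocab = 2, 8, 16, 32
hidden_states = torch.randn(batch, seq, hidden_size)
lm_head_weight = torch.randn(vocab, hidden_size)
labels = torch.randint(0, vocab, (batch, seq))

# logits_to_keep semantics from the diff: an int keeps the last N positions
# (0 keeps everything); a 1D tensor keeps explicit sequence indices.
logits_to_keep = 0
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

# Shift so that tokens < n predict n, then compute the loss from hidden
# states and the head weight (LigerForCausalLMLoss handles this shift
# internally when given `labels`).
shift_hidden = kept_hidden_states[..., :-1, :].reshape(-1, hidden_size)
shift_labels = labels[..., 1:].reshape(-1)
print(naive_fused_linear_ce(lm_head_weight, shift_hidden, shift_labels))
```
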

src/liger_kernel/transformers/monkey_patch.py
Lines changed: 4 additions & 3 deletions

```diff
@@ -314,13 +314,14 @@ def apply_liger_kernel_to_llava(
             logger.warning(TRANSFORMER_DEPRECATION_WARNING)
             modeling_llava.nn.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        if transformer_version >= version.parse("4.49.0"):
+        if transformer_version >= version.parse("4.52.0"):
             modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward
+        elif transformer_version >= version.parse("4.49.0") and transformer_version < version.parse("4.52.0"):
+            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
         else:  # if version < 4.49.0
             logger.warning(
-                "Support for transformers versions < 4.49.0 will soon be discontinued due to issues with incorrect legacy processing. \n Please consider upgrading to avoid potential issues. See details: https://github.com/huggingface/transformers/pull/35526"
+                "The latest version of Liger does not support transformers < 4.49.0 for llava. Please downgrade your liger version or upgrade your transformer version."
             )
-            modeling_llava.LlavaForConditionalGeneration.forward = llava_lce_forward_deprecated
 
     if model is not None:
         text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
```
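
For context, a hedged sketch of how this version gating is exercised from user code; the checkpoint name and the `from_pretrained` call are illustrative, and `apply_liger_kernel_to_llava` is the entry point whose body is patched above:

```python
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_llava

# With transformers >= 4.52.0 this installs llava_lce_forward; with
# 4.49.0 <= version < 4.52.0 it installs llava_lce_forward_deprecated;
# anything older only triggers the warning above.
apply_liger_kernel_to_llava(fused_linear_cross_entropy=True)

model = transformers.LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf"  # illustrative checkpoint
)
```
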
