
Commit 53a37a5

Tcc0403 and lancerts authored
fix: Fix qwen2_vl and qwen2_5_vl monkey patch (#738)
## Summary

Resolve part of #729, #713

## Details

### For monkey patch

Determine the class of the model instance and apply the Liger Kernel patches to the text model and vision model accordingly. Add the supported model types to the `MODEL_TYPE_TO_APPLY_LIGER_FN` dictionary so that the corresponding patch functions are called correctly (see the sketch after this description).

### For convergence test

A `video_processor` is now required for `Qwen2VLProcessor` and `Qwen2_5_VLProcessor`.

## Testing Done

<details>
<summary>transformers/monkey_patch</summary>

```
❯ python3 -m pytest test/transformers/test_monkey_patch.py -k vl
================================================= test session starts ==================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.0, xdist-3.6.1
collected 30 items / 24 deselected / 6 selected

test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_qwen2_vl_for_conditional_generation
---------------------------------------------------- live log call -----------------------------------------------------
INFO liger_kernel.transformers.monkey_patch:monkey_patch.py:1668 Applying Liger kernels to model instance with model type: qwen2_vl with kwargs: {}
PASSED [ 16%]
test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_qwen2_vl
---------------------------------------------------- live log call -----------------------------------------------------
INFO liger_kernel.transformers.monkey_patch:monkey_patch.py:1668 Applying Liger kernels to model instance with model type: qwen2_vl with kwargs: {}
PASSED [ 33%]
test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_qwen2_vl_text
---------------------------------------------------- live log call -----------------------------------------------------
INFO liger_kernel.transformers.monkey_patch:monkey_patch.py:1668 Applying Liger kernels to model instance with model type: qwen2_vl_text with kwargs: {}
PASSED [ 50%]
test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_qwen2_5_vl
---------------------------------------------------- live log call -----------------------------------------------------
INFO liger_kernel.transformers.monkey_patch:monkey_patch.py:1668 Applying Liger kernels to model instance with model type: qwen2_5_vl with kwargs: {}
PASSED [ 66%]
test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_qwen2_5_vl_for_conditional_generation
---------------------------------------------------- live log call -----------------------------------------------------
INFO liger_kernel.transformers.monkey_patch:monkey_patch.py:1668 Applying Liger kernels to model instance with model type: qwen2_5_vl with kwargs: {}
PASSED [ 83%]
test/transformers/test_monkey_patch.py::test_apply_liger_kernel_to_instance_for_qwen2_5_vl_text
---------------------------------------------------- live log call -----------------------------------------------------
INFO liger_kernel.transformers.monkey_patch:monkey_patch.py:1668 Applying Liger kernels to model instance with model type: qwen2_5_vl_text with kwargs: {}
PASSED
```

</details>

Convergence tests:

<details>
<summary>bf16/test_mini_models</summary>

```
❯ python3 -m pytest test/convergence/bf16/test_mini_models.py -k vl
================================================= test session starts ==================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.0, xdist-3.6.1
collecting ...
------------------------------------------------- live log collection --------------------------------------------------
INFO datasets:config.py:54 PyTorch version 2.7.0 available.
collected 16 items / 14 deselected / 2 selected

test/convergence/bf16/test_mini_models.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype7-0.001-0.05-1-0.1-0.01-0.01] PASSED [ 50%]
test/convergence/bf16/test_mini_models.py::test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype8-0.001-0.05-3-0.1-0.01-0.01] PASSED [100%]
```

</details>

<details>
<summary>bf16/test_mini_models_multimodal</summary>

```
❯ python3 -m pytest test/convergence/bf16/test_mini_models_multimodal.py -k qwen2
================================================= test session starts ==================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.0, xdist-3.6.1
collecting ...
------------------------------------------------- live log collection --------------------------------------------------
INFO datasets:config.py:54 PyTorch version 2.7.0 available.
collected 7 items / 5 deselected / 2 selected

test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_qwen2_vl-32-0.0001-dtype0-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [ 50%]
test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_qwen2_5_vl-32-0.0001-dtype2-0.001-0.01-0.1-0.01-0.01-0.01] PASSED [100%]
```

</details>

<details>
<summary>fp32/test_mini_models</summary>

```
❯ python3 -m pytest test/convergence/fp32/test_mini_models.py -k vl
================================================= test session starts ==================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.0, xdist-3.6.1
collecting ...
------------------------------------------------- live log collection --------------------------------------------------
INFO datasets:config.py:54 PyTorch version 2.7.0 available.
collected 17 items / 15 deselected / 2 selected

test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_vl-32-0.0001-dtype7-1e-05-0.1-1-0.1-0.005-1e-05] PASSED [ 50%]
test/convergence/fp32/test_mini_models.py::test_mini_model[mini_qwen2_5_vl-32-0.0001-dtype8-1e-05-0.1-3-0.1-0.005-1e-05] PASSED [100%]
```

</details>

<details>
<summary>fp32/test_mini_models_multimodal</summary>

```
❯ python3 -m pytest test/convergence/fp32/test_mini_models_multimodal.py -k vl
================================================= test session starts ==================================================
platform linux -- Python 3.10.12, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/tcc/Liger-Kernel
configfile: pyproject.toml
plugins: rerunfailures-15.0, xdist-3.6.1
collecting ...
------------------------------------------------- live log collection --------------------------------------------------
INFO datasets:config.py:54 PyTorch version 2.7.0 available.
collected 7 items / 5 deselected / 2 selected

test/convergence/fp32/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_qwen2_vl-32-0.0001-dtype0-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [ 50%]
test/convergence/fp32/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_qwen2_5_vl-32-0.0001-dtype2-1e-08-1e-05-0.005-1e-05-0.005-1e-05] PASSED [100%]
```

</details>

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Tcc0403 <[email protected]>
Co-authored-by: Shao Tang <[email protected]>
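For readers following the monkey-patch change above, here is a minimal, hypothetical sketch of instance-based dispatch through a `MODEL_TYPE_TO_APPLY_LIGER_FN`-style mapping. Only the dictionary name and the `qwen2_5_vl` / `qwen2_5_vl_text` model types come from this PR and its test logs; the patch functions below are placeholders, not Liger Kernel's actual implementations.

```python
from transformers import PreTrainedModel


def apply_liger_kernel_to_qwen2_5_vl(model=None, **kwargs):
    # Placeholder: the real patch would swap in Liger kernels (RMSNorm, RoPE,
    # fused linear cross entropy, ...) on the full multimodal model.
    print(f"patching multimodal model: {type(model).__name__}")


def apply_liger_kernel_to_qwen2_5_vl_text(model=None, **kwargs):
    # Placeholder: the real patch would target only the text (language) model.
    print(f"patching text model: {type(model).__name__}")


# Map config.model_type -> patch function, mirroring the dictionary named in the PR.
MODEL_TYPE_TO_APPLY_LIGER_FN = {
    "qwen2_5_vl": apply_liger_kernel_to_qwen2_5_vl,
    "qwen2_5_vl_text": apply_liger_kernel_to_qwen2_5_vl_text,
}


def apply_liger_kernel_to_instance(model: PreTrainedModel, **kwargs) -> None:
    # Dispatch on the instance's model type so text-only and full multimodal
    # instances each receive the appropriate patch.
    model_type = getattr(model.config, "model_type", None)
    apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(model_type)
    if apply_fn is None:
        raise ValueError(f"Unsupported model type for Liger patching: {model_type}")
    apply_fn(model=model, **kwargs)
```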
1 parent 6853d5d commit 53a37a5

File tree

10 files changed (+441 / -242 lines)

src/liger_kernel/transformers/model/qwen2_5_vl.py

Lines changed: 29 additions & 87 deletions
```diff
@@ -5,12 +5,13 @@
 
 import torch
 
-from torch.nn import CrossEntropyLoss
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast
+from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -34,14 +35,22 @@ def lce_forward(
     **kwargs,
 ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
     r"""
-    Copy paste Qwen2_5_VL's forward but replace torch cross entropy with liger fused linear cross entropy
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2_5_VLImageProcessor.__call__`] for details. [`Qwen2_5_VLProcessor`] uses
+        [`Qwen2_5_VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
+    second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
+        The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
 
     Example:
 
@@ -73,78 +82,20 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.dtype)
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-
-            mask = input_ids == self.config.image_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            image_mask = mask_expanded.to(inputs_embeds.device)
-
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-
-            mask = input_ids == self.config.video_token_id
-            mask_unsqueezed = mask.unsqueeze(-1)
-            mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
-            video_mask = mask_expanded.to(inputs_embeds.device)
-
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-        # calculate RoPE index once per generation in the pre-fill stage only
-        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-            position_ids, rope_deltas = self.get_rope_index(
-                input_ids,
-                image_grid_thw,
-                video_grid_thw,
-                second_per_grid_ts,
-                attention_mask,
-            )
-            self.rope_deltas = rope_deltas
-        # then use the prev pre-calculated rope-deltas to get the correct position ids
-        else:
-            batch_size, seq_length, _ = inputs_embeds.shape
-            delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
-            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-            if cache_position is not None:  # otherwise `deltas` is an int `0`
-                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-            position_ids = position_ids.add(delta)
-            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=None,
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
+        second_per_grid_ts=second_per_grid_ts,
         position_ids=position_ids,
         attention_mask=attention_mask,
         past_key_values=past_key_values,
@@ -180,19 +131,10 @@ def lce_forward(
         )
     else:
         logits = self.lm_head(hidden_states)
+
+    loss = None
     if labels is not None:
-        # Upcast to float if we need to compute the loss to avoid potential precision issues
-        logits = logits.float()
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
+        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
     if not return_dict:
         output = (logits,) + outputs[1:]
@@ -204,5 +146,5 @@ def lce_forward(
         past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
     )
```
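For context on the loss change in this diff: the deleted block implemented the standard causal-LM shift-and-flatten cross entropy, which the new code delegates to `self.loss_function(...)`. The snippet below is a standalone, reference-only restatement of what the removed lines computed; it is not part of the patched code path.

```python
import torch
from torch.nn import CrossEntropyLoss


def shifted_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    """Reference restatement of the removed loss block: shift so tokens < n
    predict n, flatten, and ignore positions labeled -100."""
    logits = logits.float()  # upcast to avoid precision issues
    shift_logits = logits[..., :-1, :].contiguous()  # drop the last position's logits
    shift_labels = labels[..., 1:].contiguous()  # drop the first label
    loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
    return loss_fct(
        shift_logits.view(-1, vocab_size),
        shift_labels.view(-1).to(shift_logits.device),
    )


# Tiny example: batch of 2 sequences, 5 tokens, vocabulary of 11
logits = torch.randn(2, 5, 11)
labels = torch.randint(0, 11, (2, 5))
print(shifted_cross_entropy(logits, labels, vocab_size=11))
```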

src/liger_kernel/transformers/model/qwen2_vl.py

Lines changed: 26 additions & 96 deletions
```diff
@@ -5,14 +5,13 @@
 
 import torch
 
-from packaging import version
-from torch.nn import CrossEntropyLoss
-from transformers import __version__ as transformers_version
 from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
+from transformers.utils import can_return_tuple
 
 from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
 
 
+@can_return_tuple
 def lce_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -35,15 +34,20 @@ def lce_forward(
     **kwargs,
 ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
     r"""
-    Copy paste Qwen2VL's forward but replace torch cross entropy with liger fused linear cross entropy
-
-    Args:
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
-            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
-            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
-    Returns:
+    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+    pixel_values_videos (`torch.FloatTensor` of shape `(seq_length, num_channels * temporal_size * image_size * image_size)):
+        The tensors corresponding to the input videos. Pixel values can be obtained using
+        [`AutoImageProcessor`]. See [`Qwen2VLImageProcessor.__call__`] for details. [`Qwen2VLProcessor`] uses
+        [`Qwen2VLImageProcessor`] for processing videos.
+    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        The temporal, height and width of feature shape of each image in LLM.
+    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+        The temporal, height and width of feature shape of each video in LLM.
+    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
+        The rope index difference between sequence length and multimodal rope.
 
     Example:
 
@@ -75,80 +79,19 @@ def lce_forward(
     >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
     "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
     ```"""
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
     )
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    if inputs_embeds is None:
-        inputs_embeds = self.model.embed_tokens(input_ids)
-        if pixel_values is not None:
-            pixel_values = pixel_values.type(self.visual.get_dtype())
-            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
-            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
-            n_image_features = image_embeds.shape[0]
-            if n_image_tokens != n_image_features:
-                raise ValueError(
-                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
-                )
-            image_mask = (
-                (input_ids == self.config.image_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
-
-        if pixel_values_videos is not None:
-            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
-            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
-            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
-            n_video_features = video_embeds.shape[0]
-            if n_video_tokens != n_video_features:
-                raise ValueError(
-                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
-                )
-            video_mask = (
-                (input_ids == self.config.video_token_id)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
-
-        if attention_mask is not None:
-            attention_mask = attention_mask.to(inputs_embeds.device)
-
-    if version.parse(transformers_version) > version.parse("4.46.3"):
-        # NOTE: this bug fix for qwen2-vl is not applied until transformers 4.47.0
-        # https://github.com/huggingface/transformers/issues/33401
-        # While correct, this breaks equivalence with past versions of Qwen2-VL from
-        # transformers and leads to failed tests or users noticing differences in results.
-        # TODO: remove above conditional when liger drops support for transformers<4.47.0
-        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
-        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
-            # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids, image_grid_thw, video_grid_thw, attention_mask
-                )
-                self.rope_deltas = rope_deltas
-            # then use the prev pre-calculated rope-deltas to get the correct position ids
-            else:
-                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
-                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                if cache_position is not None:  # otherwise `deltas` is an int `0`
-                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-
     outputs = self.model(
-        input_ids=None,
+        input_ids=input_ids,
+        pixel_values=pixel_values,
+        pixel_values_videos=pixel_values_videos,
+        image_grid_thw=image_grid_thw,
+        video_grid_thw=video_grid_thw,
         position_ids=position_ids,
         attention_mask=attention_mask,
         past_key_values=past_key_values,
@@ -184,29 +127,16 @@ def lce_forward(
         )
     else:
         logits = self.lm_head(hidden_states)
+
+    loss = None
     if labels is not None:
-        # Upcast to float if we need to compute the loss to avoid potential precision issues
-        logits = logits.float()
-        # Shift so that tokens < n predict n
-        shift_logits = logits[..., :-1, :].contiguous()
-        shift_labels = labels[..., 1:].contiguous()
-        # Flatten the tokens
-        loss_fct = CrossEntropyLoss()
-        shift_logits = shift_logits.view(-1, self.config.vocab_size)
-        shift_labels = shift_labels.view(-1)
-        # Enable model parallelism
-        shift_labels = shift_labels.to(shift_logits.device)
-        loss = loss_fct(shift_logits, shift_labels)
-
-    if not return_dict:
-        output = (logits,) + outputs[1:]
-        return (loss,) + output if loss is not None else output
+        loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)
 
     return Qwen2VLCausalLMOutputWithPast(
         loss=loss,
         logits=logits,
        past_key_values=outputs.past_key_values,
         hidden_states=outputs.hidden_states,
         attentions=outputs.attentions,
-        rope_deltas=rope_deltas,
+        rope_deltas=outputs.rope_deltas,
     )
```
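Both forwards are now decorated with `@can_return_tuple` from `transformers.utils`, which is why the manual `if not return_dict:` tuple conversion above could be deleted. The sketch below mimics that behavior under the assumption that the wrapped forward returns a `ModelOutput`-style object exposing `to_tuple()`; it is an illustration, not the transformers implementation.

```python
import functools


def return_tuple_when_requested(forward):
    """Simplified mimic of a return-dict-aware decorator (illustrative only)."""

    @functools.wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        output = forward(self, *args, return_dict=return_dict, **kwargs)
        use_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
        # ModelOutput subclasses expose to_tuple(); otherwise return the object as-is.
        if not use_dict and hasattr(output, "to_tuple"):
            return output.to_tuple()
        return output

    return wrapper
```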

0 commit comments
