
Commit de32d30

0russwest0 authored and Jintao-Huang committed
Fix: Correct training hang for Keye-VL on DeepSpeed with mixed data (#4889)
1 parent 99d2eeb commit de32d30
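As the diff below suggests, the hang is the familiar DeepSpeed pitfall with mixed multimodal data: when a batch (or a rank) contains only plain-text samples, the vision tower is never called there, its collective operations never fire, and the other ranks block. The new `_post_encode` therefore runs the Keye-VL visual model on a dummy 32x32 black image for plain-text batches and folds the result in as `image_embeds.mean() * 0.`, so every rank executes the same graph while the loss is unaffected.

A minimal sketch of that pattern in isolation (not the actual KeyeVLTemplate code; `vision_tower`, `image_processor`, and `text_embeds` are placeholder names, and the vision tower is assumed to return a plain embedding tensor):

import torch
from PIL import Image


def keep_vision_tower_in_step(text_embeds: torch.Tensor, vision_tower, image_processor) -> torch.Tensor:
    """Run the vision tower on a dummy image so a plain-text batch still exercises it
    (needed under DeepSpeed ZeRO so all ranks stay synchronized)."""
    dummy = Image.new('RGB', (32, 32), (0, 0, 0))
    pixel_values = image_processor(images=[dummy], return_tensors='pt')['pixel_values']
    pixel_values = pixel_values.to(text_embeds.device, text_embeds.dtype)
    dummy_embeds = vision_tower(pixel_values)  # assumed to return a tensor of embeddings
    # Scaling by zero leaves the loss unchanged, but the vision parameters still
    # take part in forward/backward, so no rank waits on a missing collective.
    return text_embeds + dummy_embeds.mean() * 0.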

File tree

1 file changed: +191 -2 lines changed


swift/llm/template/template/kwai.py

Lines changed: 191 additions & 2 deletions
@@ -3,8 +3,8 @@
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Literal

+import numpy as np
 import torch
-from transformers.dynamic_module_utils import get_class_from_dynamic_module

 from swift.llm import to_device
 from swift.utils import is_deepspeed_enabled
@@ -13,7 +13,6 @@
 from ..register import register_template
 from ..template_inputs import StdTemplateInputs
 from ..utils import Context, Word, findall
-from .qwen import Qwen2VLTemplate
 from .utils import ChatmlTemplateMeta


@@ -89,6 +88,196 @@ def _get_new_tokens(i):
         encoded['labels'] = labels
         return encoded

+    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        if not self.is_training:
+            return inputs
+        input_ids = inputs['input_ids']
+        pixel_values = inputs.get('pixel_values')
+        pixel_values_videos = inputs.get('pixel_values_videos')
+        image_grid_thw = inputs.get('image_grid_thw')
+        video_grid_thw = inputs.get('video_grid_thw')
+
+        base_model = self.get_base_model(model)
+        if hasattr(base_model.model, 'embed_tokens'):
+            inputs_embeds = base_model.model.embed_tokens(input_ids)
+        else:
+            inputs_embeds = base_model.model.language_model.embed_tokens(input_ids)
+
+        # Get dtype from visual model, adapting for KeyeVL model structure
+        if hasattr(model.visual, 'get_dtype'):
+            dtype = model.visual.get_dtype()
+        else:
+            dtype = model.visual.dtype
+
+        if pixel_values is None and pixel_values_videos is None:  # plain-text
+            if is_deepspeed_enabled():
+                from PIL import Image
+                images = [Image.new('RGB', (32, 32), (0, 0, 0))]
+                media_inputs = self.processor.image_processor(images=images, videos=None, return_tensors='pt')
+                device = input_ids.device
+                media_inputs = to_device(media_inputs, device)
+                pixel_values = media_inputs['pixel_values'].type(dtype)
+                # Convert to 5D format for KeyeVL: [num_patches, 3, 14, 14] -> [1, num_patches, 3, 14, 14]
+                pixel_values = pixel_values.unsqueeze(0)
+
+                # KeyeVL requires position_ids when pixel_values is 5D
+                num_patches = pixel_values.shape[1]
+                position_ids = torch.arange(num_patches, device=device)
+
+                # Create dummy grid that works with mlp_AR
+                # Assuming merge_size is 2, we need h and w divisible by merge_size
+                merge_size = getattr(self.processor.image_processor, 'merge_size', 2)
+                grid_size = int(np.sqrt(num_patches))
+
+                # Adjust grid_size to be divisible by merge_size
+                if grid_size % merge_size != 0:
+                    grid_size = ((grid_size + merge_size - 1) // merge_size) * merge_size
+
+                # For dummy case, use square layout that's compatible with mlp_AR
+                dummy_grid_hw = [(1, grid_size, grid_size)]
+                sample_indices = torch.zeros(num_patches, dtype=torch.int64, device=device)
+                cu_seqlens = torch.tensor([0, num_patches], dtype=torch.int32, device=device)
+
+                vision_outputs = model.visual(
+                    pixel_values=pixel_values,
+                    image_grid_thw=dummy_grid_hw,
+                    position_ids=position_ids,
+                    vision_return_embed_list=True,
+                    interpolate_pos_encoding=True,
+                    sample_indices=sample_indices,
+                    cu_seqlens=cu_seqlens,
+                    return_pooler_output=False,
+                    use_rope=True,
+                    window_size=-1,
+                )
+                image_embeds = vision_outputs.last_hidden_state
+                # Process through projector like in normal cases
+                image_embeds = model.mlp_AR(image_embeds, dummy_grid_hw)
+                # Concatenate all embeddings
+                image_embeds = torch.cat(image_embeds, dim=0)
+                inputs_embeds += image_embeds.mean() * 0.
+        else:
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(dtype)
+                # KeyeVL expects 5D input: (batch_size, sequence_len, channel, height, width)
+                # where sequence_len is the total number of patches from all images
+                pixel_values = pixel_values.unsqueeze(0)  # [num_patches, 3, 14, 14] -> [1, num_patches, 3, 14, 14]
+
+                if image_grid_thw is not None:
+                    image_grid_hws = []
+                    for thw in image_grid_thw:
+                        if isinstance(thw, torch.Tensor):
+                            thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
+                        else:
+                            thw_tuple = tuple(thw)
+                        image_grid_hws.append(thw_tuple)
+
+                    # Prepare position_ids and other parameters for KeyeVL
+                    siglip_position_ids = []
+                    sample_indices = []
+                    cu_seqlens = [0]
+
+                    for idx, thw_tuple in enumerate(image_grid_hws):
+                        numel = np.prod(thw_tuple)
+                        image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
+                        siglip_position_ids.append(image_position_ids)
+                        sample_indices.append(torch.full((numel, ), idx, dtype=torch.int64))
+                        cu_seqlens.append(cu_seqlens[-1] + numel)
+
+                    siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(pixel_values.device)
+                    cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(pixel_values.device)
+                    sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values.device)
+
+                    # Call KeyeVL visual model
+                    vision_outputs = model.visual(
+                        pixel_values=pixel_values,
+                        image_grid_thw=image_grid_hws,
+                        position_ids=siglip_position_ids,
+                        vision_return_embed_list=True,
+                        interpolate_pos_encoding=True,
+                        sample_indices=sample_indices,
+                        cu_seqlens=cu_seqlens,
+                        return_pooler_output=False,
+                        use_rope=True,
+                        window_size=-1,
+                    )
+                    image_embeds = vision_outputs.last_hidden_state
+
+                    # Process through projector
+                    image_embeds = model.mlp_AR(image_embeds, image_grid_thw)
+                    # Concatenate all image embeddings
+                    image_embeds = torch.cat(image_embeds, dim=0)
+                else:
+                    # Fallback for case without grid info
+                    num_patches = pixel_values.shape[1]
+                    position_ids = torch.arange(num_patches, device=pixel_values.device)
+                    vision_outputs = model.visual(pixel_values=pixel_values, position_ids=position_ids)
+                    image_embeds = vision_outputs.last_hidden_state.reshape(-1,
+                                                                            vision_outputs.last_hidden_state.shape[-1])
+
+                image_mask = (input_ids == model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+            if pixel_values_videos is not None:
+                pixel_values_videos = pixel_values_videos.type(dtype)
+                # Same processing for videos: convert to 5D format
+                pixel_values_videos = pixel_values_videos.unsqueeze(
+                    0)  # [num_patches, 3, 14, 14] -> [1, num_patches, 3, 14, 14]
+
+                if video_grid_thw is not None:
+                    video_grid_hws = []
+                    for thw in video_grid_thw:
+                        if isinstance(thw, torch.Tensor):
+                            thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
+                        else:
+                            thw_tuple = tuple(thw)
+                        video_grid_hws.append(thw_tuple)
+
+                    siglip_position_ids = []
+                    sample_indices = []
+                    cu_seqlens = [0]
+
+                    for idx, thw_tuple in enumerate(video_grid_hws):
+                        numel = np.prod(thw_tuple)
+                        video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
+                        siglip_position_ids.append(video_position_ids)
+                        sample_indices.append(torch.full((numel, ), idx, dtype=torch.int64))
+                        cu_seqlens.append(cu_seqlens[-1] + numel)
+
+                    siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(pixel_values_videos.device)
+                    cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(pixel_values_videos.device)
+                    sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values_videos.device)
+
+                    vision_outputs = model.visual(
+                        pixel_values=pixel_values_videos,
+                        image_grid_thw=video_grid_hws,
+                        position_ids=siglip_position_ids,
+                        vision_return_embed_list=True,
+                        interpolate_pos_encoding=True,
+                        sample_indices=sample_indices,
+                        cu_seqlens=cu_seqlens,
+                        return_pooler_output=False,
+                        use_rope=True,
+                        window_size=-1,
+                    )
+                    video_embeds = vision_outputs.last_hidden_state
+                    video_embeds = model.mlp_AR(video_embeds, video_grid_thw)
+                    video_embeds = torch.cat(video_embeds, dim=0)
+                else:
+                    # Fallback for case without grid info
+                    num_patches = pixel_values_videos.shape[1]
+                    position_ids = torch.arange(num_patches, device=pixel_values_videos.device)
+                    vision_outputs = model.visual(pixel_values=pixel_values_videos, position_ids=position_ids)
+                    video_embeds = vision_outputs.last_hidden_state.reshape(-1,
+                                                                            vision_outputs.last_hidden_state.shape[-1])
+
+                video_mask = (input_ids == model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+        return {'inputs_embeds': inputs_embeds}
+
     def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
         res = super()._data_collator_mm_data(batch)
         second_per_grid_ts = self.gather_list(batch, 'second_per_grid_ts')
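For reference, the per-image bookkeeping the new method builds (siglip_position_ids, sample_indices, cu_seqlens) reduces to a few lines. The snippet below is a standalone illustration, not code from the commit, and the grid sizes are made up:

import numpy as np
import torch

# Two hypothetical image grids in (t, h, w) patch units (values chosen for illustration only).
image_grid_hws = [(1, 4, 6), (1, 8, 8)]

siglip_position_ids, sample_indices, cu_seqlens = [], [], [0]
for idx, thw in enumerate(image_grid_hws):
    numel = int(np.prod(thw))  # 24 patches, then 64 patches
    # Position ids restart for every frame: patch index modulo h * w.
    siglip_position_ids.append(torch.arange(numel) % int(np.prod(thw[1:])))
    # Every patch is tagged with the index of the image it belongs to.
    sample_indices.append(torch.full((numel, ), idx, dtype=torch.int64))
    # Cumulative lengths mark image boundaries in the flattened patch sequence.
    cu_seqlens.append(cu_seqlens[-1] + numel)

print(cu_seqlens)                                        # [0, 24, 88]
print(torch.concat(sample_indices).bincount().tolist())  # [24, 64]
print(torch.concat(siglip_position_ids)[:6].tolist())    # [0, 1, 2, 3, 4, 5]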

0 commit comments
