 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Literal

+import numpy as np
 import torch
-from transformers.dynamic_module_utils import get_class_from_dynamic_module

 from swift.llm import to_device
 from swift.utils import is_deepspeed_enabled
 from ..register import register_template
 from ..template_inputs import StdTemplateInputs
 from ..utils import Context, Word, findall
-from .qwen import Qwen2VLTemplate
 from .utils import ChatmlTemplateMeta

@@ -89,6 +88,196 @@ def _get_new_tokens(i):
         encoded['labels'] = labels
         return encoded

+    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        if not self.is_training:
+            return inputs
+        input_ids = inputs['input_ids']
+        pixel_values = inputs.get('pixel_values')
+        pixel_values_videos = inputs.get('pixel_values_videos')
+        image_grid_thw = inputs.get('image_grid_thw')
+        video_grid_thw = inputs.get('video_grid_thw')
+
+        base_model = self.get_base_model(model)
+        if hasattr(base_model.model, 'embed_tokens'):
+            inputs_embeds = base_model.model.embed_tokens(input_ids)
+        else:
+            inputs_embeds = base_model.model.language_model.embed_tokens(input_ids)
+
+        # Get dtype from visual model, adapting for KeyeVL model structure
+        if hasattr(model.visual, 'get_dtype'):
+            dtype = model.visual.get_dtype()
+        else:
+            dtype = model.visual.dtype
+
+        if pixel_values is None and pixel_values_videos is None:  # plain-text
+            if is_deepspeed_enabled():
+                from PIL import Image
+                images = [Image.new('RGB', (32, 32), (0, 0, 0))]
+                media_inputs = self.processor.image_processor(images=images, videos=None, return_tensors='pt')
+                device = input_ids.device
+                media_inputs = to_device(media_inputs, device)
+                pixel_values = media_inputs['pixel_values'].type(dtype)
+                # Convert to 5D format for KeyeVL: [num_patches, 3, 14, 14] -> [1, num_patches, 3, 14, 14]
+                pixel_values = pixel_values.unsqueeze(0)
+
+                # KeyeVL requires position_ids when pixel_values is 5D
+                num_patches = pixel_values.shape[1]
+                position_ids = torch.arange(num_patches, device=device)
+
+                # Create dummy grid that works with mlp_AR
+                # Assuming merge_size is 2, we need h and w divisible by merge_size
+                merge_size = getattr(self.processor.image_processor, 'merge_size', 2)
+                grid_size = int(np.sqrt(num_patches))
+
+                # Adjust grid_size to be divisible by merge_size
+                if grid_size % merge_size != 0:
+                    grid_size = ((grid_size + merge_size - 1) // merge_size) * merge_size
+
+                # For dummy case, use square layout that's compatible with mlp_AR
+                dummy_grid_hw = [(1, grid_size, grid_size)]
+                sample_indices = torch.zeros(num_patches, dtype=torch.int64, device=device)
+                cu_seqlens = torch.tensor([0, num_patches], dtype=torch.int32, device=device)
+
+                vision_outputs = model.visual(
+                    pixel_values=pixel_values,
+                    image_grid_thw=dummy_grid_hw,
+                    position_ids=position_ids,
+                    vision_return_embed_list=True,
+                    interpolate_pos_encoding=True,
+                    sample_indices=sample_indices,
+                    cu_seqlens=cu_seqlens,
+                    return_pooler_output=False,
+                    use_rope=True,
+                    window_size=-1,
+                )
+                image_embeds = vision_outputs.last_hidden_state
+                # Process through projector like in normal cases
+                image_embeds = model.mlp_AR(image_embeds, dummy_grid_hw)
+                # Concatenate all embeddings
+                image_embeds = torch.cat(image_embeds, dim=0)
+                inputs_embeds += image_embeds.mean() * 0.
+        else:
+            if pixel_values is not None:
+                pixel_values = pixel_values.type(dtype)
+                # KeyeVL expects 5D input: (batch_size, sequence_len, channel, height, width)
+                # where sequence_len is the total number of patches from all images
+                pixel_values = pixel_values.unsqueeze(0)  # [num_patches, 3, 14, 14] -> [1, num_patches, 3, 14, 14]
+
+                if image_grid_thw is not None:
+                    image_grid_hws = []
+                    for thw in image_grid_thw:
+                        if isinstance(thw, torch.Tensor):
+                            thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
+                        else:
+                            thw_tuple = tuple(thw)
+                        image_grid_hws.append(thw_tuple)
+
+                    # Prepare position_ids and other parameters for KeyeVL
+                    siglip_position_ids = []
+                    sample_indices = []
+                    cu_seqlens = [0]
+
+                    for idx, thw_tuple in enumerate(image_grid_hws):
+                        numel = np.prod(thw_tuple)
+                        image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
+                        siglip_position_ids.append(image_position_ids)
+                        sample_indices.append(torch.full((numel, ), idx, dtype=torch.int64))
+                        cu_seqlens.append(cu_seqlens[-1] + numel)
+
+                    siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(pixel_values.device)
+                    cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(pixel_values.device)
+                    sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values.device)
+
+                    # Call KeyeVL visual model
+                    vision_outputs = model.visual(
+                        pixel_values=pixel_values,
+                        image_grid_thw=image_grid_hws,
+                        position_ids=siglip_position_ids,
+                        vision_return_embed_list=True,
+                        interpolate_pos_encoding=True,
+                        sample_indices=sample_indices,
+                        cu_seqlens=cu_seqlens,
+                        return_pooler_output=False,
+                        use_rope=True,
+                        window_size=-1,
+                    )
+                    image_embeds = vision_outputs.last_hidden_state
+
+                    # Process through projector
+                    image_embeds = model.mlp_AR(image_embeds, image_grid_thw)
+                    # Concatenate all image embeddings
+                    image_embeds = torch.cat(image_embeds, dim=0)
+                else:
+                    # Fallback for case without grid info
+                    num_patches = pixel_values.shape[1]
+                    position_ids = torch.arange(num_patches, device=pixel_values.device)
+                    vision_outputs = model.visual(pixel_values=pixel_values, position_ids=position_ids)
+                    image_embeds = vision_outputs.last_hidden_state.reshape(
+                        -1, vision_outputs.last_hidden_state.shape[-1])
+
+                image_mask = (input_ids == model.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+            if pixel_values_videos is not None:
+                pixel_values_videos = pixel_values_videos.type(dtype)
+                # Same processing for videos: convert to 5D format
+                pixel_values_videos = pixel_values_videos.unsqueeze(
+                    0)  # [num_patches, 3, 14, 14] -> [1, num_patches, 3, 14, 14]
+
+                if video_grid_thw is not None:
+                    video_grid_hws = []
+                    for thw in video_grid_thw:
+                        if isinstance(thw, torch.Tensor):
+                            thw_tuple = tuple(thw.detach().cpu().numpy().tolist())
+                        else:
+                            thw_tuple = tuple(thw)
+                        video_grid_hws.append(thw_tuple)
+
+                    siglip_position_ids = []
+                    sample_indices = []
+                    cu_seqlens = [0]
+
+                    for idx, thw_tuple in enumerate(video_grid_hws):
+                        numel = np.prod(thw_tuple)
+                        video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
+                        siglip_position_ids.append(video_position_ids)
+                        sample_indices.append(torch.full((numel, ), idx, dtype=torch.int64))
+                        cu_seqlens.append(cu_seqlens[-1] + numel)
+
+                    siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(pixel_values_videos.device)
+                    cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(pixel_values_videos.device)
+                    sample_indices = torch.concat(sample_indices, dim=0).to(pixel_values_videos.device)
+
+                    vision_outputs = model.visual(
+                        pixel_values=pixel_values_videos,
+                        image_grid_thw=video_grid_hws,
+                        position_ids=siglip_position_ids,
+                        vision_return_embed_list=True,
+                        interpolate_pos_encoding=True,
+                        sample_indices=sample_indices,
+                        cu_seqlens=cu_seqlens,
+                        return_pooler_output=False,
+                        use_rope=True,
+                        window_size=-1,
+                    )
+                    video_embeds = vision_outputs.last_hidden_state
+                    video_embeds = model.mlp_AR(video_embeds, video_grid_thw)
+                    video_embeds = torch.cat(video_embeds, dim=0)
+                else:
+                    # Fallback for case without grid info
+                    num_patches = pixel_values_videos.shape[1]
+                    position_ids = torch.arange(num_patches, device=pixel_values_videos.device)
+                    vision_outputs = model.visual(pixel_values=pixel_values_videos, position_ids=position_ids)
+                    video_embeds = vision_outputs.last_hidden_state.reshape(
+                        -1, vision_outputs.last_hidden_state.shape[-1])
+
+                video_mask = (input_ids == model.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
+                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
+
+        return {'inputs_embeds': inputs_embeds}
+
     def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
         res = super()._data_collator_mm_data(batch)
         second_per_grid_ts = self.gather_list(batch, 'second_per_grid_ts')
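
The bookkeeping in the image/video branches (siglip_position_ids, sample_indices, cu_seqlens) is easiest to follow with concrete numbers. The standalone sketch below uses made-up (t, h, w) grids rather than real processor output and only mirrors the packing logic from the diff; it does not call the KeyeVL visual model.

# Sketch of how patches from several images are packed into one flattened
# sequence. Grid sizes are hypothetical; in the real code they come from the
# processor's image_grid_thw output.
import numpy as np
import torch

image_grid_hws = [(1, 4, 6), (1, 8, 8)]  # (t, h, w) per image

siglip_position_ids, sample_indices, cu_seqlens = [], [], [0]
for idx, thw in enumerate(image_grid_hws):
    numel = int(np.prod(thw))  # patches contributed by this image
    # position ids restart every frame (h * w patches), matching `% np.prod(thw_tuple[1:])`
    siglip_position_ids.append(torch.arange(numel) % int(np.prod(thw[1:])))
    sample_indices.append(torch.full((numel, ), idx, dtype=torch.int64))  # owner image of each patch
    cu_seqlens.append(cu_seqlens[-1] + numel)  # cumulative boundaries between images

siglip_position_ids = torch.cat(siglip_position_ids)
sample_indices = torch.cat(sample_indices)
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)

print(cu_seqlens.tolist())        # [0, 24, 88]
print(sample_indices.bincount())  # tensor([24, 64])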
|
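For plain-text batches under DeepSpeed, _post_encode still pushes a 32x32 black image through the visual tower and adds image_embeds.mean() * 0. to the embeddings, so every visual parameter stays in the computation graph without changing the loss. The helper below sketches only the grid-size rounding used to build the dummy (1, g, g) grid; the patch counts and the default merge_size of 2 are illustrative assumptions.

# Minimal sketch of the dummy-grid sizing, assuming a square patch layout and
# that mlp_AR needs h and w divisible by merge_size.
import math

def dummy_grid(num_patches: int, merge_size: int = 2):
    grid_size = int(math.isqrt(num_patches))
    # Round up to the next multiple of merge_size, mirroring the diff.
    if grid_size % merge_size != 0:
        grid_size = ((grid_size + merge_size - 1) // merge_size) * merge_size
    return [(1, grid_size, grid_size)]

print(dummy_grid(4))  # [(1, 2, 2)]
print(dummy_grid(9))  # [(1, 4, 4)] -- 3 rounded up to the next multiple of 2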
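
Finally, the projected embeddings are written into the language-model input wherever the image (or video) placeholder token appears, via masked_scatter. A toy illustration with a made-up token id and tiny shapes:

# Hypothetical values: the real placeholder id comes from model.config.image_token_id.
import torch

IMAGE_TOKEN_ID = 9999
input_ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]])
inputs_embeds = torch.zeros(1, 4, 8)
image_embeds = torch.ones(2, 8)  # one row per placeholder token, after mlp_AR

image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

print(inputs_embeds[0, :, 0])  # tensor([0., 1., 1., 0.]) -- only the placeholder rows were replaced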