101 | 101 | } |
102 | 102 |
103 | 103 |
| 104 | +def _expand_input_ids_with_image_tokens( |
| 105 | + text_input_ids, |
| 106 | + prompt_attention_mask, |
| 107 | + max_sequence_length, |
| 108 | + image_token_index, |
| 109 | + image_emb_len, |
| 110 | + image_emb_start, |
| 111 | + image_emb_end, |
| 112 | + pad_token_id, |
| 113 | +): |
| 114 | + special_image_token_mask = text_input_ids == image_token_index |
| 115 | + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) |
| 116 | + batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index) |
| 117 | + |
| 118 | + max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1)) |
| 119 | + new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1 |
| 120 | + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] |
| 121 | + |
| 122 | + expanded_input_ids = torch.full( |
| 123 | + (text_input_ids.shape[0], max_expanded_length), |
| 124 | + pad_token_id, |
| 125 | + dtype=text_input_ids.dtype, |
| 126 | + device=text_input_ids.device, |
| 127 | + ) |
| 128 | + expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices] |
| 129 | + expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index |
| 130 | + |
| 131 | + expanded_attention_mask = torch.zeros( |
| 132 | + (text_input_ids.shape[0], max_expanded_length), |
| 133 | + dtype=prompt_attention_mask.dtype, |
| 134 | + device=prompt_attention_mask.device, |
| 135 | + ) |
| 136 | + attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id) |
| 137 | + expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0 |
| 138 | + expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype) |
| 139 | + position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1) |
| 140 | + |
| 141 | + return { |
| 142 | + "input_ids": expanded_input_ids, |
| 143 | + "attention_mask": expanded_attention_mask, |
| 144 | + "position_ids": position_ids, |
| 145 | + } |
| 146 | + |
| 147 | + |
104 | 148 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps |
105 | 149 | def retrieve_timesteps( |
106 | 150 | scheduler, |
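For orientation, here is a minimal toy invocation of the new `_expand_input_ids_with_image_tokens` helper, showing how a single image placeholder in the tokenized prompt is expanded into a block of image-token slots. The token ids, `max_sequence_length`, and `image_emb_*` values below are made-up illustration values (the real pipeline reads `image_token_index` and `pad_token_id` from the text encoder config, and uses `image_emb_len=576`, `image_emb_start=5`, `image_emb_end=581` from the prompt template).

```python
import torch

# Toy values for illustration only; ids and offsets are hypothetical.
text_input_ids = torch.tensor([[1, 32000, 7, 8]])   # <bos> <image> "a" "cat"
prompt_attention_mask = torch.ones_like(text_input_ids)

expanded = _expand_input_ids_with_image_tokens(
    text_input_ids,
    prompt_attention_mask,
    max_sequence_length=4,
    image_token_index=32000,
    image_emb_len=4,        # each <image> placeholder grows into 4 slots
    image_emb_start=1,
    image_emb_end=5,
    pad_token_id=0,
)
# expanded["input_ids"]      -> [[1, 32000, 32000, 32000, 32000, 7, 8]]
# expanded["attention_mask"] -> [[1, 1, 1, 1, 1, 1, 1]]
# expanded["position_ids"]   -> [[0, 1, 2, 3, 4, 5, 6]]
```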
@@ -259,6 +303,12 @@ def _get_llama_prompt_embeds( |
259 | 303 | prompt = [prompt_template["template"].format(p) for p in prompt] |
260 | 304 |
261 | 305 | crop_start = prompt_template.get("crop_start", None) |
| 306 | + |
| 307 | + image_emb_len = prompt_template.get("image_emb_len", 576) |
| 308 | + image_emb_start = prompt_template.get("image_emb_start", 5) |
| 309 | + image_emb_end = prompt_template.get("image_emb_end", 581) |
| 310 | + double_return_token_id = prompt_template.get("double_return_token_id", 271) |
| 311 | + |
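For reference, a hypothetical `prompt_template` dict matching the fallback defaults read above might look like the sketch below; the actual `template` string is defined by the pipeline's default templates and is only a stand-in here.

```python
# Hypothetical prompt_template restating the .get(...) fallbacks above.
prompt_template = {
    "template": "{}",               # stand-in; the real chat template is defined by the pipeline
    "image_emb_len": 576,           # vision-token slots inserted per image placeholder
    "image_emb_start": 5,           # first slot index in the expanded sequence
    "image_emb_end": 581,           # one past the last slot (5 + 576)
    "double_return_token_id": 271,  # id of the double-newline token, per the key name
}
```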
262 | 312 | if crop_start is None: |
263 | 313 | prompt_template_input = self.tokenizer( |
264 | 314 | prompt_template["template"], |
@@ -288,69 +338,25 @@ def _get_llama_prompt_embeds( |
288 | 338 |
289 | 339 | image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device) |
290 | 340 |
291 | | - _, _, image_height, image_width = image_embeds.shape |
292 | | - patch_size = self.text_encoder.config.vision_config.patch_size |
293 | | - num_image_tokens = (image_height // patch_size) * (image_width // patch_size) |
294 | | - if self.text_encoder.config.vision_config.vision_feature_select_strategy == "default": |
295 | | - num_image_tokens -= 1 |
296 | | - |
297 | 341 | image_token_index = self.text_encoder.config.image_token_index |
298 | 342 | pad_token_id = self.text_encoder.config.pad_token_id |
299 | | - batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index) |
300 | | - |
301 | | - special_image_token_mask = text_input_ids == image_token_index |
302 | | - num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) |
303 | | - |
304 | | - max_expanded_length = max_sequence_length + ( |
305 | | - num_special_image_tokens.max() * (prompt_template["image_emb_len"] - 1) |
306 | | - ) |
307 | | - new_token_positions = ( |
308 | | - torch.cumsum((special_image_token_mask * (prompt_template["image_emb_len"] - 1) + 1), -1) - 1 |
309 | | - ) |
310 | | - nb_image_pad = max_expanded_length - 1 - new_token_positions[:, -1] |
311 | | - if left_padding: |
312 | | - new_token_positions += nb_image_pad[:, None] |
313 | | - |
314 | | - text_to_overwrite = new_token_positions[batch_indices, non_image_indices] |
315 | | - |
316 | | - expanded_input_ids = torch.full( |
317 | | - (batch_size, max_expanded_length), pad_token_id, dtype=text_input_ids.dtype, device=device |
318 | | - ) |
319 | | - expanded_attention_mask = torch.ones( |
320 | | - (batch_size, max_expanded_length), dtype=prompt_attention_mask.dtype, device=device |
| 343 | + expanded_inputs = _expand_input_ids_with_image_tokens( |
| 344 | + text_input_ids, |
| 345 | + prompt_attention_mask, |
| 346 | + max_sequence_length, |
| 347 | + image_token_index, |
| 348 | + image_emb_len, |
| 349 | + image_emb_start, |
| 350 | + image_emb_end, |
| 351 | + pad_token_id, |
321 | 352 | ) |
322 | | - |
323 | | - expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices] |
324 | | - expanded_inputs_ids[batch_indices, prompt_template["image_emb_start"] : prompt_template["image_emb_end"]] = ( |
325 | | - image_token_index |
326 | | - ) |
327 | | - |
328 | | - inputs = self.llava_processor( |
329 | | - text=prompt, |
330 | | - images=image, |
331 | | - # max_length=max_sequence_length, |
332 | | - padding="max_length", |
333 | | - truncation=True, |
334 | | - return_length=False, |
335 | | - return_overflowing_tokens=False, |
336 | | - return_attention_mask=True, |
337 | | - return_tensors="pt", |
338 | | - ).to(device) |
339 | | - |
340 | | - text_input_ids = inputs["input_ids"] |
341 | | - prompt_attention_mask = inputs["attention_mask"] |
342 | | - |
343 | 353 | prompt_embeds = self.text_encoder( |
344 | | - **inputs, |
| 354 | + **expanded_inputs, |
| 355 | + pixel_values=image_embeds, |
345 | 356 | output_hidden_states=True, |
346 | 357 | ).hidden_states[-(num_hidden_layers_to_skip + 1)] |
347 | 358 | prompt_embeds = prompt_embeds.to(dtype=dtype) |
348 | 359 |
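Since `expanded_inputs` is the dict returned by the helper, the encoder call above is equivalent to the keyword form sketched below; `pixel_values` is the Llava-style encoder's argument for the preprocessed image, and the encoder merges the image features into the `<image>` slots that the expansion reserved. Names such as `image_embeds` and `num_hidden_layers_to_skip` are those of the surrounding method.

```python
# Sketch of the same call with the dict unpacked explicitly (not pipeline code).
outputs = self.text_encoder(
    input_ids=expanded_inputs["input_ids"],
    attention_mask=expanded_inputs["attention_mask"],
    position_ids=expanded_inputs["position_ids"],
    pixel_values=image_embeds,      # image features fill the reserved <image> slots
    output_hidden_states=True,
)
prompt_embeds = outputs.hidden_states[-(num_hidden_layers_to_skip + 1)].to(dtype=dtype)
```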
349 | | - image_emb_len = prompt_template.get("image_emb_len", 576) |
350 | | - image_emb_start = prompt_template.get("image_emb_start", 5) |
351 | | - image_emb_end = prompt_template.get("image_emb_end", 581) |
352 | | - double_return_token_id = prompt_template.get("double_return_token_id", 271) |
353 | | - |
354 | 360 | if crop_start is not None and crop_start > 0: |
355 | 361 | text_crop_start = crop_start - 1 + image_emb_len |
356 | 362 | batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id) |