|
}


+def _merge_input_ids_with_image_features(
+    image_features, inputs_embeds, input_ids, attention_mask, image_token_index, pad_token_id
+):
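+    """
+    Merge text embeddings and image patch features into a single sequence, following the
+    `_merge_input_ids_with_image_features` helper from transformers' LLaVA models.
+
+    Shapes: `image_features` is (num_images, num_image_patches, embed_dim); `inputs_embeds`,
+    `input_ids`, and `attention_mask` are (batch_size, sequence_length[, embed_dim]).
+    Returns `(final_embedding, final_attention_mask, position_ids)`, each padded to the
+    merged sequence length.
+    """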
+    num_images, num_image_patches, embed_dim = image_features.shape
+    batch_size, sequence_length = input_ids.shape
+    # Determine the padding side: if no sequence ends with the pad token, assume left padding
+    # (same check as transformers' LLaVA implementation).
+    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(pad_token_id))
+    # 1. Create a mask to know where special image tokens are
+    special_image_token_mask = input_ids == image_token_index
+    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+    # Maximum sequence length once every image token expands into `num_image_patches` embeddings
+    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
+    batch_indices, non_image_indices = torch.where(input_ids != image_token_index)
+
+    # 2. Compute the positions where text should be written
+    # Calculate new positions for text tokens in the merged image-text sequence.
+    # `special_image_token_mask` identifies image tokens. Each image token is expanded into
+    # `num_image_patches - 1` additional positions, shifting every subsequent text token.
+    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
+    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
+    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
+    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
+    if left_padding:
+        new_token_positions += nb_image_pad[:, None]  # offset for left padding
+    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
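+    # Worked example (hypothetical values): with num_image_patches = 3 and
+    # input_ids = ["hey", "<image>", "how"], the summand is [1, 3, 1] and
+    # new_token_positions = cumsum(...) - 1 = [0, 3, 4]: "hey" stays at slot 0, the image token
+    # maps to slot 3 (the last of its slots 1-3), and "how" shifts to slot 4.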
+
+    # 3. Create the full embedding, already padded to the maximum position
+    final_embedding = torch.zeros(
+        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
+    )
+    final_attention_mask = torch.zeros(
+        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
+    )
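+    # The zero-init tensors are placeholders: steps 4 and 5 overwrite every attended position,
+    # and step 6 re-zeroes embeddings at padding positions.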
+    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
+    # set the corresponding tensors into their correct target device.
+    target_device = inputs_embeds.device
+    batch_indices, non_image_indices, text_to_overwrite = (
+        batch_indices.to(target_device),
+        non_image_indices.to(target_device),
+        text_to_overwrite.to(target_device),
+    )
+    attention_mask = attention_mask.to(target_device)
+
+    # 4. Fill the embeddings based on the mask. If we have ["hey", "<image>", "how", "are"]
+    # with 576 image patches, we index-copy the text onto positions [0, 577, 578] and the image
+    # features onto positions [1:577].
+    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
+    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
+
+    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
+    image_to_overwrite = torch.full((batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device)
+    image_to_overwrite[batch_indices, text_to_overwrite] = False
+    if left_padding:
+        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
+    else:
+        mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
+        padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
+        image_to_overwrite &= padding_mask
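+    # With left padding, the first `nb_image_pad` candidate slots are skipped; with right padding,
+    # slots past the last real token are dropped, so image features never land on padding.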
+
+    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
+        raise ValueError(
+            f"The inputs provided to the model are inconsistent: the number of image tokens is {torch.sum(special_image_token_mask)}"
+            f" while the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
+        )
+
+    final_attention_mask |= image_to_overwrite
+    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
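+    # Position ids count only attended tokens; masked-out positions receive the dummy value 1.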
+
+    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
+    batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
+    indices_to_mask = new_token_positions[batch_indices, pad_indices]
+
+    final_embedding[batch_indices, indices_to_mask] = 0
+
+    return final_embedding, final_attention_mask, position_ids
+
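+# Sanity check of the merged length (hypothetical numbers): a 10-token prompt containing one
+# <image> token with 576 patches yields max_embed_dim = 1 * (576 - 1) + 10 = 585 positions.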
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
@@ -279,13 +358,30 @@ def _get_llama_prompt_embeds( |
        prompt_attention_mask = text_inputs.attention_mask.to(device=device)

        image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
+        input_embeds = self.text_encoder.get_input_embeddings()(text_input_ids)
+
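+        # `image_embeds` above holds raw pixel values, while `_merge_input_ids_with_image_features`
+        # expects projected vision features of shape (num_images, num_image_patches, embed_dim).
+        # A minimal sketch of the missing encoding step, assuming a LLaVA-style text encoder
+        # (the code already relies on `language_model` and `config.image_token_index`, so
+        # `vision_tower` and `multi_modal_projector` should be available as well):
+        image_outputs = self.text_encoder.vision_tower(image_embeds, output_hidden_states=True)
+        selected_image_feature = image_outputs.hidden_states[self.text_encoder.config.vision_feature_layer]
+        selected_image_feature = selected_image_feature[:, 1:]  # "default" strategy: drop the CLS token
+        image_features = self.text_encoder.multi_modal_projector(selected_image_feature)
+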
+        inputs_embeds, attention_mask, position_ids = _merge_input_ids_with_image_features(
+            image_features,
+            input_embeds,
+            text_input_ids,
+            prompt_attention_mask,
+            self.text_encoder.config.image_token_index,
+            self.text_encoder.pad_token_id,
+        )

+        prompt_embeds = self.text_encoder.language_model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=True,
+        ).hidden_states[-(num_hidden_layers_to_skip + 1)]
-        prompt_embeds = self.text_encoder(
-            input_ids=text_input_ids,
-            attention_mask=prompt_attention_mask,
-            pixel_values=image_embeds,
-            output_hidden_states=True,
-        ).hidden_states[-(num_hidden_layers_to_skip + 1)]
        prompt_embeds = prompt_embeds.to(dtype=dtype)

        image_emb_len = prompt_template.get("image_emb_len", 576)