24 | 24 |     CLIPTokenizer,  | 
25 | 25 |     LlamaTokenizerFast,  | 
26 | 26 |     LlavaForConditionalGeneration,  | 
 | 27 | +    LlavaProcessor,  | 
27 | 28 | )  | 
28 | 29 |  | 
29 | 30 | from ...callbacks import MultiPipelineCallbacks, PipelineCallback  | 
100 | 101 | }  | 
101 | 102 |  | 
102 | 103 |  | 
103 |  | -def _merge_input_ids_with_image_features(  | 
104 |  | -    image_features, inputs_embeds, input_ids, attention_mask, image_token_index, pad_token_id  | 
105 |  | -):  | 
106 |  | -    num_images, num_image_patches, embed_dim = image_features.shape  | 
107 |  | -    batch_size, sequence_length = input_ids.shape  | 
108 |  | -    special_image_token_mask = input_ids == image_token_index  | 
109 |  | -    # 1. Create a mask to know where special image tokens are  | 
110 |  | -    special_image_token_mask = input_ids == image_token_index  | 
111 |  | -    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)  | 
112 |  | -    batch_indices, non_image_indices = torch.where(input_ids != image_token_index)  | 
113 |  | -    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length  | 
114 |  | -    batch_indices, non_image_indices = torch.where(input_ids != image_token_index)  | 
115 |  | - | 
116 |  | -    # 2. Compute the positions where text should be written  | 
117 |  | -    # Calculate new positions for text tokens in merged image-text sequence.  | 
118 |  | -    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.  | 
119 |  | -    # `torch.cumsum` computes how each image token shifts subsequent text token positions.  | 
120 |  | -    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.  | 
121 |  | -    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1  | 
122 |  | -    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]  | 
123 |  | -    if left_padding:  | 
124 |  | -        new_token_positions += nb_image_pad[:, None]  # offset for left padding  | 
125 |  | -    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]  | 
126 |  | - | 
127 |  | -    # 3. Create the full embedding, already padded to the maximum position  | 
128 |  | -    final_embedding = torch.zeros(  | 
129 |  | -        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device  | 
130 |  | -    )  | 
131 |  | -    final_attention_mask = torch.zeros(  | 
132 |  | -        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device  | 
133 |  | -    )  | 
134 |  | -    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually  | 
135 |  | -    # set the corresponding tensors into their correct target device.  | 
136 |  | -    target_device = inputs_embeds.device  | 
137 |  | -    batch_indices, non_image_indices, text_to_overwrite = (  | 
138 |  | -        batch_indices.to(target_device),  | 
139 |  | -        non_image_indices.to(target_device),  | 
140 |  | -        text_to_overwrite.to(target_device),  | 
141 |  | -    )  | 
142 |  | -    attention_mask = attention_mask.to(target_device)  | 
143 |  | - | 
144 |  | -    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]  | 
145 |  | -    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features  | 
146 |  | -    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]  | 
147 |  | -    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]  | 
148 |  | - | 
149 |  | -    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)  | 
150 |  | -    image_to_overwrite = torch.full((batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device)  | 
151 |  | -    image_to_overwrite[batch_indices, text_to_overwrite] = False  | 
152 |  | -    if left_padding:  | 
153 |  | -        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)  | 
154 |  | -    else:  | 
155 |  | -        mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1  | 
156 |  | -        padding_mask = mask <= new_token_positions[:, -1:].to(target_device)  | 
157 |  | -        image_to_overwrite &= padding_mask  | 
158 |  | - | 
159 |  | -    if image_to_overwrite.sum() != image_features.shape[:-1].numel():  | 
160 |  | -        raise ValueError(  | 
161 |  | -            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"  | 
162 |  | -            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."  | 
163 |  | -        )  | 
164 |  | - | 
165 |  | -    batch_indices, pad_indices = torch.where(input_ids == pad_token_id)  | 
166 |  | -    final_attention_mask |= image_to_overwrite  | 
167 |  | -    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)  | 
168 |  | - | 
169 |  | -    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.  | 
170 |  | -    batch_indices, pad_indices = torch.where(input_ids == pad_token_id)  | 
171 |  | -    indices_to_mask = new_token_positions[batch_indices, pad_indices]  | 
172 |  | - | 
173 |  | -    final_embedding[batch_indices, indices_to_mask] = 0  | 
174 |  | - | 
175 |  | -    return final_embedding, final_attention_mask, position_ids  | 
176 |  | - | 
177 |  | - | 
178 |  | -def _text_encoder_custom_forward():  | 
179 |  | -    return  | 
180 |  | - | 
181 |  | - | 
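The deleted helper above re-implements, at the embedding level, the image-feature scatter that `LlavaForConditionalGeneration` performs internally; the only subtle part is the length bookkeeping, where every `<image>` placeholder grows the merged sequence by `num_image_patches - 1` positions. A small illustrative sketch of that arithmetic, assuming a 336x336 CLIP vision tower with `patch_size=14` (576 patches per image, the shapes the comments above refer to):

```python
# Illustrative bookkeeping only, not pipeline code. Assumes a 336x336 CLIP
# vision tower with patch_size=14, i.e. 24 * 24 = 576 patches per image.
num_image_patches = 576        # embeddings the vision tower produces per image
sequence_length = 16           # tokenized prompt length, one <image> placeholder included
num_special_image_tokens = 1   # number of <image> tokens found in input_ids

# Each <image> token is replaced by num_image_patches embeddings, so the merged
# sequence gains (num_image_patches - 1) positions per image token -- exactly the
# max_embed_dim formula in the deleted helper.
max_embed_dim = num_special_image_tokens * (num_image_patches - 1) + sequence_length
assert max_embed_dim == 591  # 16 text positions + 575 extra image positions
```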
182 | 104 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps  | 
183 | 105 | def retrieve_timesteps(  | 
184 | 106 |     scheduler,  | 
@@ -310,6 +232,13 @@ def __init__(  | 
310 | 232 |         self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4  | 
311 | 233 |         self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8  | 
312 | 234 |         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)  | 
 | 235 | +        self.llava_processor = LlavaProcessor(  | 
 | 236 | +            self.image_processor,  | 
 | 237 | +            self.tokenizer,  | 
 | 238 | +            patch_size=self.text_encoder.config.vision_config.patch_size,  | 
 | 239 | +            vision_feature_select_strategy=self.text_encoder.config.vision_feature_select_strategy,  | 
 | 240 | +            num_additional_image_tokens=1,  | 
 | 241 | +        )  | 
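The processor built here handles the token-level half of that work: it tokenizes the prompt and expands each `<image>` placeholder into as many image tokens as the vision tower will emit, controlled by `patch_size`, `vision_feature_select_strategy`, and `num_additional_image_tokens`. A minimal, self-contained sketch of the same wiring, using the public `llava-hf/llava-1.5-7b-hf` checkpoint purely as a stand-in for the pipeline's own tokenizer and image processor, and assuming the `"default"` feature-select strategy:

```python
# Sketch only; the pipeline builds the processor from its own components, and the
# checkpoint id here is just a convenient public stand-in.
from transformers import AutoTokenizer, CLIPImageProcessor, LlavaProcessor

repo = "llava-hf/llava-1.5-7b-hf"
image_processor = CLIPImageProcessor.from_pretrained(repo)  # 336x336 crops
tokenizer = AutoTokenizer.from_pretrained(repo)

llava_processor = LlavaProcessor(
    image_processor,
    tokenizer,
    patch_size=14,                             # text_encoder.config.vision_config.patch_size
    vision_feature_select_strategy="default",  # assumed here; the pipeline reads it from the config
    num_additional_image_tokens=1,             # accounts for the vision tower's CLS token
)

# With the "default" strategy each <image> placeholder is expanded into
# (336 // 14) ** 2 + 1 - 1 = 576 image tokens in input_ids.
```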
313 | 242 |  | 
314 | 243 |     def _get_llama_prompt_embeds(  | 
315 | 244 |         self,  | 
@@ -358,30 +287,25 @@ def _get_llama_prompt_embeds(  | 
358 | 287 |         prompt_attention_mask = text_inputs.attention_mask.to(device=device)  | 
359 | 288 |  | 
360 | 289 |         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)  | 
361 |  | -        input_embeds = self.text_encoder.get_input_embeddings()(text_input_ids)  | 
362 |  | - | 
363 |  | -        inputs_embeds, attention_mask, position_ids = _merge_input_ids_with_image_features(  | 
364 |  | -            image_embeds,  | 
365 |  | -            input_embeds,  | 
366 |  | -            text_input_ids,  | 
367 |  | -            prompt_attention_mask,  | 
368 |  | -            self.text_encoder.config.image_token_index,  | 
369 |  | -            self.text_encoder.pad_token_id,  | 
370 |  | -        )  | 
 | 290 | +        inputs = self.llava_processor(  | 
 | 291 | +            text=prompt,  | 
 | 292 | +            images=image,  | 
 | 293 | +            # max_length=max_sequence_length,  | 
 | 294 | +            padding="max_length",  | 
 | 295 | +            truncation=True,  | 
 | 296 | +            return_length=False,  | 
 | 297 | +            return_overflowing_tokens=False,  | 
 | 298 | +            return_attention_mask=True,  | 
 | 299 | +            return_tensors="pt",  | 
 | 300 | +        ).to(device)  | 
 | 301 | + | 
 | 302 | +        text_input_ids = inputs["input_ids"]  | 
 | 303 | +        prompt_attention_mask = inputs["attention_mask"]  | 
371 | 304 | 
 
  | 
372 |  | -        prompt_embeds = self.text_encoder.language_model(  | 
373 |  | -            attention_mask=attention_mask,  | 
374 |  | -            position_ids=position_ids,  | 
375 |  | -            inputs_embeds=inputs_embeds,  | 
376 |  | -        ).hidden_states[-(num_hidden_layers_to_skip + 1)]  | 
377 |  | -        """  | 
378 | 305 |         prompt_embeds = self.text_encoder(  | 
379 |  | -            input_ids=text_input_ids,  | 
380 |  | -            attention_mask=prompt_attention_mask,  | 
381 |  | -            pixel_values=image_embeds,  | 
 | 306 | +            **inputs,  | 
382 | 307 |             output_hidden_states=True,  | 
383 | 308 |         ).hidden_states[-(num_hidden_layers_to_skip + 1)]  | 
384 |  | -        """  | 
385 | 309 |         prompt_embeds = prompt_embeds.to(dtype=dtype)  | 
386 | 310 |  | 
387 | 311 |         image_emb_len = prompt_template.get("image_emb_len", 576)  | 
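With these changes, image-conditioned prompt encoding reduces to one processor call followed by a single `LlavaForConditionalGeneration` forward pass with `output_hidden_states=True`, from which an intermediate hidden state is kept. A standalone sketch of that path, again substituting a public Llava checkpoint for the pipeline's text encoder and omitting the pipeline's prompt template and `max_length` handling:

```python
# Sketch under stated assumptions: a public Llava checkpoint stands in for the
# pipeline's text_encoder/processor pair; prompt templating is omitted.
import torch
from PIL import Image
from transformers import LlavaForConditionalGeneration, LlavaProcessor

repo = "llava-hf/llava-1.5-7b-hf"
processor = LlavaProcessor.from_pretrained(repo)
text_encoder = LlavaForConditionalGeneration.from_pretrained(repo)

image = Image.new("RGB", (336, 336))  # dummy conditioning frame
inputs = processor(
    text="USER: <image>\nDescribe the first frame of the video. ASSISTANT:",
    images=image,
    return_tensors="pt",
)

num_hidden_layers_to_skip = 2  # mirrors the pipeline argument of the same name
with torch.no_grad():
    outputs = text_encoder(**inputs, output_hidden_states=True)

# Same indexing as the pipeline: keep an intermediate layer, not the last one.
prompt_embeds = outputs.hidden_states[-(num_hidden_layers_to_skip + 1)]
```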