diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
index 774b72e6c7c1..d3c8a3539b98 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py
@@ -100,6 +100,50 @@
 }
 
 
+def _expand_input_ids_with_image_tokens(
+    text_input_ids,
+    prompt_attention_mask,
+    max_sequence_length,
+    image_token_index,
+    image_emb_len,
+    image_emb_start,
+    image_emb_end,
+    pad_token_id,
+):
+    special_image_token_mask = text_input_ids == image_token_index
+    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+    batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
+
+    max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
+    new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
+    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+    expanded_input_ids = torch.full(
+        (text_input_ids.shape[0], max_expanded_length),
+        pad_token_id,
+        dtype=text_input_ids.dtype,
+        device=text_input_ids.device,
+    )
+    expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
+    expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
+
+    expanded_attention_mask = torch.zeros(
+        (text_input_ids.shape[0], max_expanded_length),
+        dtype=prompt_attention_mask.dtype,
+        device=prompt_attention_mask.device,
+    )
+    attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
+    expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
+    expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
+    position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
+
+    return {
+        "input_ids": expanded_input_ids,
+        "attention_mask": expanded_attention_mask,
+        "position_ids": position_ids,
+    }
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -251,6 +295,12 @@ def _get_llama_prompt_embeds(
         prompt = [prompt_template["template"].format(p) for p in prompt]
 
         crop_start = prompt_template.get("crop_start", None)
+
+        image_emb_len = prompt_template.get("image_emb_len", 576)
+        image_emb_start = prompt_template.get("image_emb_start", 5)
+        image_emb_end = prompt_template.get("image_emb_end", 581)
+        double_return_token_id = prompt_template.get("double_return_token_id", 271)
+
         if crop_start is None:
             prompt_template_input = self.tokenizer(
                 prompt_template["template"],
@@ -280,19 +330,25 @@ def _get_llama_prompt_embeds(
 
         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
 
+        image_token_index = self.text_encoder.config.image_token_index
+        pad_token_id = self.text_encoder.config.pad_token_id
+        expanded_inputs = _expand_input_ids_with_image_tokens(
+            text_input_ids,
+            prompt_attention_mask,
+            max_sequence_length,
+            image_token_index,
+            image_emb_len,
+            image_emb_start,
+            image_emb_end,
+            pad_token_id,
+        )
         prompt_embeds = self.text_encoder(
-            input_ids=text_input_ids,
-            attention_mask=prompt_attention_mask,
-            pixel_values=image_embeds,
+            **expanded_inputs,
+            pixel_values=image_embeds,
             output_hidden_states=True,
         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
         prompt_embeds = prompt_embeds.to(dtype=dtype)
 
-        image_emb_len = prompt_template.get("image_emb_len", 576)
-        image_emb_start = prompt_template.get("image_emb_start", 5)
-        image_emb_end = prompt_template.get("image_emb_end", 581)
-        double_return_token_id = prompt_template.get("double_return_token_id", 271)
-
         if crop_start is not None and crop_start > 0:
             text_crop_start = crop_start - 1 + image_emb_len
             batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
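
For reference (not part of the patch): a minimal sanity check of the new expansion helper on toy tensors, assuming the patched module is importable. The token ids and the small image_emb_* values below are arbitrary stand-ins for the real template, where the placeholder sits at index 5 and expands into a 576-token image span (image_emb_end = 5 + 576 = 581).

import torch

from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video_image2video import (
    _expand_input_ids_with_image_tokens,
)

# Toy prompt: 6 tokens with one image placeholder (id 32000) at index 2.
# With image_emb_len=4, that placeholder should expand into 4 image slots
# covering positions [2, 6).
text_input_ids = torch.tensor([[101, 7, 32000, 8, 9, 102]])
prompt_attention_mask = torch.ones_like(text_input_ids)

expanded = _expand_input_ids_with_image_tokens(
    text_input_ids,
    prompt_attention_mask,
    max_sequence_length=6,
    image_token_index=32000,
    image_emb_len=4,
    image_emb_start=2,
    image_emb_end=6,
    pad_token_id=0,
)

# The sequence grows from 6 to 6 + (4 - 1) = 9 tokens; text tokens keep their
# order around the repeated image token, and positions are a simple arange.
assert expanded["input_ids"].tolist() == [[101, 7, 32000, 32000, 32000, 32000, 8, 9, 102]]
assert expanded["attention_mask"].tolist() == [[1, 1, 1, 1, 1, 1, 1, 1, 1]]
assert expanded["position_ids"].tolist() == [[0, 1, 2, 3, 4, 5, 6, 7, 8]]

This expansion is also why the crop offset above becomes crop_start - 1 + image_emb_len: the single placeholder token in the tokenized template is replaced by image_emb_len image slots in the expanded sequence.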