Commit d4b6f6c

update

1 parent 169d45c
1 file changed (+32, −0)


src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py

Lines changed: 32 additions & 0 deletions
@@ -287,12 +287,44 @@ def _get_llama_prompt_embeds(
         prompt_attention_mask = text_inputs.attention_mask.to(device=device)
 
         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
+
         _, _, image_height, image_width = image_embeds.shape
         patch_size = self.text_encoder.config.vision_config.patch_size
         num_image_tokens = (image_height // patch_size) * (image_width // patch_size)
         if self.text_encoder.config.vision_config.vision_feature_select_strategy == "default":
             num_image_tokens -= 1
 
+        image_token_index = self.text_encoder.config.image_token_index
+        pad_token_id = self.text_encoder.config.pad_token_id
+        batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
+
+        special_image_token_mask = text_input_ids == image_token_index
+        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+
+        max_expanded_length = max_sequence_length + (
+            num_special_image_tokens.max() * (prompt_template["image_emb_len"] - 1)
+        )
+        new_token_positions = (
+            torch.cumsum((special_image_token_mask * (prompt_template["image_emb_len"] - 1) + 1), -1) - 1
+        )
+        nb_image_pad = max_expanded_length - 1 - new_token_positions[:, -1]
+        if left_padding:
+            new_token_positions += nb_image_pad[:, None]
+
+        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+        expanded_input_ids = torch.full(
+            (batch_size, max_expanded_length), pad_token_id, dtype=text_input_ids.dtype, device=device
+        )
+        expanded_attention_mask = torch.ones(
+            (batch_size, max_expanded_length), dtype=prompt_attention_mask.dtype, device=device
+        )
+
+        expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
+        expanded_input_ids[batch_indices, prompt_template["image_emb_start"] : prompt_template["image_emb_end"]] = (
+            image_token_index
+        )
+
         inputs = self.llava_processor(
             text=prompt,
             images=image,
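
For reference, the added block follows the familiar LLaVA-style placeholder-expansion scheme: each image token in the prompt is widened into prompt_template["image_emb_len"] slots, and a cumulative sum over the placeholder mask assigns every surviving text token its new position in the expanded sequence. Below is a minimal standalone sketch of that arithmetic, using toy token ids and a toy image_emb_len of 4 (assumptions for illustration; the real values come from the text-encoder config and prompt_template):

import torch

# Toy stand-ins (assumptions, not the pipeline's real values): the actual
# image_token_index / pad_token_id come from self.text_encoder.config and
# image_emb_len from prompt_template["image_emb_len"].
image_token_index = 9
pad_token_id = 0
image_emb_len = 4
left_padding = False

# Batch of one prompt: [bos, <image>, tok_a, tok_b]
text_input_ids = torch.tensor([[1, 9, 5, 6]])
batch_size, max_sequence_length = text_input_ids.shape

special_image_token_mask = text_input_ids == image_token_index
num_special_image_tokens = special_image_token_mask.sum(dim=-1)

# Every ordinary token advances the write cursor by 1, the placeholder by
# image_emb_len; cumsum - 1 therefore yields each token's final position
# (for the placeholder itself: the last of its reserved slots).
max_expanded_length = max_sequence_length + int(num_special_image_tokens.max()) * (image_emb_len - 1)
new_token_positions = torch.cumsum(special_image_token_mask * (image_emb_len - 1) + 1, dim=-1) - 1
nb_image_pad = max_expanded_length - 1 - new_token_positions[:, -1]
if left_padding:
    new_token_positions += nb_image_pad[:, None]

# Scatter the surviving text tokens into a pad-initialized buffer; the gap
# left behind is exactly the run of slots the image embeddings will occupy.
batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
expanded_input_ids = torch.full((batch_size, max_expanded_length), pad_token_id, dtype=text_input_ids.dtype)
expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]

print(new_token_positions)  # tensor([[0, 4, 5, 6]]) -> slots 1..4 reserved for the image
print(expanded_input_ids)   # tensor([[1, 0, 0, 0, 0, 5, 6]])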
