Skip to content

Commit 0f599d9

Browse files
committed
update
1 parent b8093e6 commit 0f599d9

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,85 @@
100100
}
101101

102102

103+
def _merge_input_ids_with_image_features(
104+
image_features, inputs_embeds, input_ids, attention_mask, image_token_index, pad_token_id
105+
):
106+
num_images, num_image_patches, embed_dim = image_features.shape
107+
batch_size, sequence_length = input_ids.shape
108+
special_image_token_mask = input_ids == image_token_index
109+
# 1. Create a mask to know where special image tokens are
110+
special_image_token_mask = input_ids == image_token_index
111+
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
112+
batch_indices, non_image_indices = torch.where(input_ids != image_token_index)
113+
max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
114+
batch_indices, non_image_indices = torch.where(input_ids != image_token_index)
115+
116+
# 2. Compute the positions where text should be written
117+
# Calculate new positions for text tokens in merged image-text sequence.
118+
# `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
119+
# `torch.cumsum` computes how each image token shifts subsequent text token positions.
120+
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
121+
new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
122+
nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
123+
if left_padding:
124+
new_token_positions += nb_image_pad[:, None] # offset for left padding
125+
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
126+
127+
# 3. Create the full embedding, already padded to the maximum position
128+
final_embedding = torch.zeros(
129+
batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
130+
)
131+
final_attention_mask = torch.zeros(
132+
batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
133+
)
134+
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
135+
# set the corresponding tensors into their correct target device.
136+
target_device = inputs_embeds.device
137+
batch_indices, non_image_indices, text_to_overwrite = (
138+
batch_indices.to(target_device),
139+
non_image_indices.to(target_device),
140+
text_to_overwrite.to(target_device),
141+
)
142+
attention_mask = attention_mask.to(target_device)
143+
144+
# 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
145+
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
146+
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
147+
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
148+
149+
# 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
150+
image_to_overwrite = torch.full((batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device)
151+
image_to_overwrite[batch_indices, text_to_overwrite] = False
152+
if left_padding:
153+
image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
154+
else:
155+
mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
156+
padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
157+
image_to_overwrite &= padding_mask
158+
159+
if image_to_overwrite.sum() != image_features.shape[:-1].numel():
160+
raise ValueError(
161+
f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
162+
f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
163+
)
164+
165+
batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
166+
final_attention_mask |= image_to_overwrite
167+
position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
168+
169+
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
170+
batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
171+
indices_to_mask = new_token_positions[batch_indices, pad_indices]
172+
173+
final_embedding[batch_indices, indices_to_mask] = 0
174+
175+
return final_embedding, final_attention_mask, position_ids
176+
177+
178+
def _text_encoder_custom_forward():
179+
return
180+
181+
103182
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104183
def retrieve_timesteps(
105184
scheduler,
@@ -279,13 +358,30 @@ def _get_llama_prompt_embeds(
279358
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
280359

281360
image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
361+
input_embeds = self.text_encoder.get_input_embeddings()(text_input_ids)
362+
363+
inputs_embeds, attention_mask, position_ids = _merge_input_ids_with_image_features(
364+
image_embeds,
365+
input_embeds,
366+
text_input_ids,
367+
prompt_attention_mask,
368+
self.text_encoder.config.image_token_index,
369+
self.text_encoder.pad_token_id,
370+
)
282371

372+
prompt_embeds = self.text_encoder.language_model(
373+
attention_mask=attention_mask,
374+
position_ids=position_ids,
375+
inputs_embeds=inputs_embeds,
376+
).hidden_states[-(num_hidden_layers_to_skip + 1)]
377+
"""
283378
prompt_embeds = self.text_encoder(
284379
input_ids=text_input_ids,
285380
attention_mask=prompt_attention_mask,
286381
pixel_values=image_embeds,
287382
output_hidden_states=True,
288383
).hidden_states[-(num_hidden_layers_to_skip + 1)]
384+
"""
289385
prompt_embeds = prompt_embeds.to(dtype=dtype)
290386

291387
image_emb_len = prompt_template.get("image_emb_len", 576)

0 commit comments

Comments
 (0)