
Commit 5c669f8

Commit message: update

1 parent 0f599d9 · commit 5c669f8

File tree: 1 file changed, +23 −99 lines

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py

Lines changed: 23 additions & 99 deletions
```diff
@@ -24,6 +24,7 @@
     CLIPTokenizer,
     LlamaTokenizerFast,
     LlavaForConditionalGeneration,
+    LlavaProcessor,
 )
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
```
```diff
@@ -100,85 +101,6 @@
 }
 
 
-def _merge_input_ids_with_image_features(
-    image_features, inputs_embeds, input_ids, attention_mask, image_token_index, pad_token_id
-):
-    num_images, num_image_patches, embed_dim = image_features.shape
-    batch_size, sequence_length = input_ids.shape
-    special_image_token_mask = input_ids == image_token_index
-    # 1. Create a mask to know where special image tokens are
-    special_image_token_mask = input_ids == image_token_index
-    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
-    batch_indices, non_image_indices = torch.where(input_ids != image_token_index)
-    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
-    batch_indices, non_image_indices = torch.where(input_ids != image_token_index)
-
-    # 2. Compute the positions where text should be written
-    # Calculate new positions for text tokens in merged image-text sequence.
-    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
-    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
-    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
-    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
-    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
-    if left_padding:
-        new_token_positions += nb_image_pad[:, None]  # offset for left padding
-    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
-    # 3. Create the full embedding, already padded to the maximum position
-    final_embedding = torch.zeros(
-        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
-    )
-    final_attention_mask = torch.zeros(
-        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
-    )
-    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
-    # set the corresponding tensors into their correct target device.
-    target_device = inputs_embeds.device
-    batch_indices, non_image_indices, text_to_overwrite = (
-        batch_indices.to(target_device),
-        non_image_indices.to(target_device),
-        text_to_overwrite.to(target_device),
-    )
-    attention_mask = attention_mask.to(target_device)
-
-    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
-    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
-    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
-    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
-
-    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
-    image_to_overwrite = torch.full((batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device)
-    image_to_overwrite[batch_indices, text_to_overwrite] = False
-    if left_padding:
-        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
-    else:
-        mask = torch.ones_like(image_to_overwrite, dtype=torch.bool).cumsum(-1) - 1
-        padding_mask = mask <= new_token_positions[:, -1:].to(target_device)
-        image_to_overwrite &= padding_mask
-
-    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
-        raise ValueError(
-            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
-            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
-        )
-
-    batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
-    final_attention_mask |= image_to_overwrite
-    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
-
-    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
-    batch_indices, pad_indices = torch.where(input_ids == pad_token_id)
-    indices_to_mask = new_token_positions[batch_indices, pad_indices]
-
-    final_embedding[batch_indices, indices_to_mask] = 0
-
-    return final_embedding, final_attention_mask, position_ids
-
-
-def _text_encoder_custom_forward():
-    return
-
-
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
```
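The helper deleted above re-implemented, inside diffusers, the image/text merging that `LlavaForConditionalGeneration` already performs when it receives `pixel_values` together with `input_ids` containing expanded image placeholders: text token embeddings are scattered into a longer sequence and the image patch embeddings fill the slots left by the `<image>` token, with the attention mask and position ids rebuilt to match. As a rough illustration of what that merge amounts to, here is a simplified sketch (not the deleted code; it assumes batch size 1 and no padding, purely for illustration):

```python
# Simplified sketch of merging image patch embeddings into a text sequence at
# the position of one <image> token. Shapes are illustrative assumptions.
import torch

embed_dim = 8
text_embeds = torch.randn(5, embed_dim)       # ["hey", "<image>", "how", "are", "you"]
image_token_pos = 1                           # index of the <image> placeholder
image_features = torch.randn(576, embed_dim)  # one image -> 576 patch embeddings

merged = torch.cat(
    [
        text_embeds[:image_token_pos],        # tokens before the image
        image_features,                       # patches replace the placeholder
        text_embeds[image_token_pos + 1 :],   # tokens after the image
    ],
    dim=0,
)
print(merged.shape)  # torch.Size([580, 8]) -> 4 text tokens + 576 image tokens
```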
```diff
@@ -310,6 +232,13 @@ def __init__(
         self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+        self.llava_processor = LlavaProcessor(
+            self.image_processor,
+            self.tokenizer,
+            patch_size=self.text_encoder.config.vision_config.patch_size,
+            vision_feature_select_strategy=self.text_encoder.config.vision_feature_select_strategy,
+            num_additional_image_tokens=1,
+        )
 
     def _get_llama_prompt_embeds(
         self,
```
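The three keyword arguments added here are what let the processor expand each `<image>` placeholder into the correct number of image tokens before anything reaches the language model. A back-of-the-envelope sketch of that count, assuming a CLIP-L/336 vision tower with 14-pixel patches (an assumption; the real values come from `text_encoder.config.vision_config` at pipeline init):

```python
# Rough sketch of the per-image token count the processor works with.
# The 336/14 geometry below is an assumption, not stated in this diff.
image_size = 336
patch_size = 14                        # patch_size passed to LlavaProcessor
num_additional_image_tokens = 1        # e.g. the vision tower's CLS token
vision_feature_select_strategy = "default"

num_image_tokens = (image_size // patch_size) ** 2 + num_additional_image_tokens
if vision_feature_select_strategy == "default":
    # the "default" strategy drops the CLS feature, so one token fewer
    num_image_tokens -= 1

print(num_image_tokens)  # 576 -- consistent with image_emb_len defaulting to 576 later in this file
```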
```diff
@@ -358,30 +287,25 @@ def _get_llama_prompt_embeds(
         prompt_attention_mask = text_inputs.attention_mask.to(device=device)
 
         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
-        input_embeds = self.text_encoder.get_input_embeddings()(text_input_ids)
-
-        inputs_embeds, attention_mask, position_ids = _merge_input_ids_with_image_features(
-            image_embeds,
-            input_embeds,
-            text_input_ids,
-            prompt_attention_mask,
-            self.text_encoder.config.image_token_index,
-            self.text_encoder.pad_token_id,
-        )
+        inputs = self.llava_processor(
+            text=prompt,
+            images=image,
+            # max_length=max_sequence_length,
+            padding="max_length",
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_attention_mask=True,
+            return_tensors="pt",
+        ).to(device)
+
+        text_input_ids = inputs["input_ids"]
+        prompt_attention_mask = inputs["attention_mask"]
 
-        prompt_embeds = self.text_encoder.language_model(
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            inputs_embeds=inputs_embeds,
-        ).hidden_states[-(num_hidden_layers_to_skip + 1)]
-        """
         prompt_embeds = self.text_encoder(
-            input_ids=text_input_ids,
-            attention_mask=prompt_attention_mask,
-            pixel_values=image_embeds,
+            **inputs,
             output_hidden_states=True,
         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
-        """
         prompt_embeds = prompt_embeds.to(dtype=dtype)
 
         image_emb_len = prompt_template.get("image_emb_len", 576)
```
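After this change, the prompt-embedding path is: run the LLaVA processor on the raw prompt and image, then forward the resulting `input_ids`, `attention_mask`, and `pixel_values` through `LlavaForConditionalGeneration` and take an intermediate hidden state. A standalone sketch of that flow, outside the pipeline (the checkpoint id, prompt, and placeholder image are illustrative assumptions; inside the pipeline the equivalents are `self.llava_processor`, `self.text_encoder`, and the user-supplied prompt and image):

```python
# Standalone sketch of the new embedding path -- checkpoint id and inputs are
# illustrative assumptions, not part of this commit.
import torch
from PIL import Image
from transformers import LlavaForConditionalGeneration, LlavaProcessor

model_id = "llava-hf/llava-1.5-7b-hf"  # assumption: any LLaVA-style checkpoint
processor = LlavaProcessor.from_pretrained(model_id)
text_encoder = LlavaForConditionalGeneration.from_pretrained(model_id)

prompt = "Describe the image. <image>"
image = Image.new("RGB", (336, 336))   # placeholder image

# The processor expands <image> into the right number of image tokens and
# returns input_ids, attention_mask and pixel_values in one batch.
inputs = processor(text=prompt, images=image, return_tensors="pt")

with torch.no_grad():
    outputs = text_encoder(**inputs, output_hidden_states=True)

num_hidden_layers_to_skip = 2  # the pipeline exposes this as a parameter
prompt_embeds = outputs.hidden_states[-(num_hidden_layers_to_skip + 1)]
print(prompt_embeds.shape)     # (batch, expanded_sequence_length, hidden_size)
```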
