
Commit 6e15e47

update

1 parent d4b6f6c commit 6e15e47
File tree

1 file changed (+61, -55 lines)

1 file changed

+61
-55
lines changed

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py

Lines changed: 61 additions & 55 deletions
@@ -101,6 +101,50 @@
 }
 
 
+def _expand_input_ids_with_image_tokens(
+    text_input_ids,
+    prompt_attention_mask,
+    max_sequence_length,
+    image_token_index,
+    image_emb_len,
+    image_emb_start,
+    image_emb_end,
+    pad_token_id,
+):
+    special_image_token_mask = text_input_ids == image_token_index
+    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
+    batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
+
+    max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
+    new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
+    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
+
+    expanded_input_ids = torch.full(
+        (text_input_ids.shape[0], max_expanded_length),
+        pad_token_id,
+        dtype=text_input_ids.dtype,
+        device=text_input_ids.device,
+    )
+    expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
+    expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
+
+    expanded_attention_mask = torch.zeros(
+        (text_input_ids.shape[0], max_expanded_length),
+        dtype=prompt_attention_mask.dtype,
+        device=prompt_attention_mask.device,
+    )
+    attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
+    expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
+    expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
+    position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
+
+    return {
+        "input_ids": expanded_input_ids,
+        "attention_mask": expanded_attention_mask,
+        "position_ids": position_ids,
+    }
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -259,6 +303,12 @@ def _get_llama_prompt_embeds(
         prompt = [prompt_template["template"].format(p) for p in prompt]
 
         crop_start = prompt_template.get("crop_start", None)
+
+        image_emb_len = prompt_template.get("image_emb_len", 576)
+        image_emb_start = prompt_template.get("image_emb_start", 5)
+        image_emb_end = prompt_template.get("image_emb_end", 581)
+        double_return_token_id = prompt_template.get("double_return_token_id", 271)
+
         if crop_start is None:
             prompt_template_input = self.tokenizer(
                 prompt_template["template"],
@@ -288,69 +338,25 @@ def _get_llama_prompt_embeds(
 
         image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
 
-        _, _, image_height, image_width = image_embeds.shape
-        patch_size = self.text_encoder.config.vision_config.patch_size
-        num_image_tokens = (image_height // patch_size) * (image_width // patch_size)
-        if self.text_encoder.config.vision_config.vision_feature_select_strategy == "default":
-            num_image_tokens -= 1
-
         image_token_index = self.text_encoder.config.image_token_index
         pad_token_id = self.text_encoder.config.pad_token_id
-        batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
-
-        special_image_token_mask = text_input_ids == image_token_index
-        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
-
-        max_expanded_length = max_sequence_length + (
-            num_special_image_tokens.max() * (prompt_template["image_emb_len"] - 1)
-        )
-        new_token_positions = (
-            torch.cumsum((special_image_token_mask * (prompt_template["image_emb_len"] - 1) + 1), -1) - 1
-        )
-        nb_image_pad = max_expanded_length - 1 - new_token_positions[:, -1]
-        if left_padding:
-            new_token_positions += nb_image_pad[:, None]
-
-        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
-
-        expanded_input_ids = torch.full(
-            (batch_size, max_expanded_length), pad_token_id, dtype=text_input_ids.dtype, device=device
-        )
-        expanded_attention_mask = torch.ones(
-            (batch_size, max_expanded_length), dtype=prompt_attention_mask.dtype, device=device
+        expanded_inputs = _expand_input_ids_with_image_tokens(
+            text_input_ids,
+            prompt_attention_mask,
+            max_sequence_length,
+            image_token_index,
+            image_emb_len,
+            image_emb_start,
+            image_emb_end,
+            pad_token_id,
         )
-
-        expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
-        expanded_inputs_ids[batch_indices, prompt_template["image_emb_start"] : prompt_template["image_emb_end"]] = (
-            image_token_index
-        )
-
-        inputs = self.llava_processor(
-            text=prompt,
-            images=image,
-            # max_length=max_sequence_length,
-            padding="max_length",
-            truncation=True,
-            return_length=False,
-            return_overflowing_tokens=False,
-            return_attention_mask=True,
-            return_tensors="pt",
-        ).to(device)
-
-        text_input_ids = inputs["input_ids"]
-        prompt_attention_mask = inputs["attention_mask"]
-
         prompt_embeds = self.text_encoder(
-            **inputs,
+            **expanded_inputs,
+            pixel_value=image_embeds,
             output_hidden_states=True,
         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
         prompt_embeds = prompt_embeds.to(dtype=dtype)
 
-        image_emb_len = prompt_template.get("image_emb_len", 576)
-        image_emb_start = prompt_template.get("image_emb_start", 5)
-        image_emb_end = prompt_template.get("image_emb_end", 581)
-        double_return_token_id = prompt_template.get("double_return_token_id", 271)
-
         if crop_start is not None and crop_start > 0:
             text_crop_start = crop_start - 1 + image_emb_len
             batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)