
Commit 655dcda

update
1 parent 77abad3 commit 655dcda

1 file changed: +150 -40 lines changed

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py

Lines changed: 150 additions & 40 deletions
@@ -26,6 +26,7 @@
 )
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
 from ...loaders import HunyuanVideoLoraLoaderMixin
 from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -75,15 +76,20 @@
 
 DEFAULT_PROMPT_TEMPLATE = {
     "template": (
-        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
+        "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
         "1. The main content and theme of the video."
         "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
         "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
         "4. background environment, light, style and atmosphere."
-        "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
+        "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
         "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+        "<|start_header_id|>assistant<|end_header_id|>\n\n"
     ),
-    "crop_start": 95,
+    "crop_start": 103,
+    "image_emb_start": 5,
+    "image_emb_end": 581,
+    "image_emb_len": 576,
+    "double_return_token_id": 271,
 }
 
 
@@ -147,6 +153,20 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
     r"""
     Pipeline for image-to-video generation using HunyuanVideo.
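Note: a standalone usage sketch of the retrieve_latents helper added above, assuming that function is in scope. The stub classes below only stand in for a real VAE encode() output (which exposes .latent_dist); they are not part of diffusers.

import torch
from dataclasses import dataclass

@dataclass
class FakeLatentDist:
    mean: torch.Tensor
    def sample(self, generator=None):
        return self.mean + torch.randn_like(self.mean)
    def mode(self):
        return self.mean

@dataclass
class FakeEncodeOutput:
    latent_dist: FakeLatentDist

encoded = FakeEncodeOutput(FakeLatentDist(mean=torch.zeros(1, 16, 1, 8, 8)))
print(retrieve_latents(encoded, sample_mode="sample").shape)  # stochastic sample from the distribution
print(retrieve_latents(encoded, sample_mode="argmax").shape)  # deterministic mode of the distribution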
@@ -197,6 +217,7 @@ def __init__(
             scheduler=scheduler,
             text_encoder_2=text_encoder_2,
             tokenizer_2=tokenizer_2,
+            image_processor=image_processor,
         )
 
         self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
@@ -205,13 +226,15 @@ def __init__(
 
     def _get_llama_prompt_embeds(
         self,
+        image: torch.Tensor,
         prompt: Union[str, List[str]],
         prompt_template: Dict[str, Any],
         num_videos_per_prompt: int = 1,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         max_sequence_length: int = 256,
         num_hidden_layers_to_skip: int = 2,
+        image_embed_interleave: int = 2,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         device = device or self._execution_device
         dtype = dtype or self.text_encoder.dtype
@@ -232,8 +255,8 @@ def _get_llama_prompt_embeds(
             return_attention_mask=False,
         )
         crop_start = prompt_template_input["input_ids"].shape[-1]
-        # Remove <|eot_id|> token and placeholder {}
-        crop_start -= 2
+        # Remove <|start_header_id|>, <|end_header_id|>, assistant, <|eot_id|>, and placeholder {}
+        crop_start -= 5
 
         max_sequence_length += crop_start
         text_inputs = self.tokenizer(
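Note: crop_start is now reduced by 5 instead of 2 because the template ends with an assistant header in addition to the final <|eot_id|> and the {} placeholder; together with the template change above, the default crop_start becomes 103 and the new image_emb_* entries record where the expanded image tokens sit in the multimodal sequence. A small sanity-check sketch of those constants (the 24 x 24 patch-grid reading of the 576 tokens is an assumption about the LLaVA-style vision tower, not something stated in this diff):

template_meta = {
    "crop_start": 103,
    "image_emb_start": 5,
    "image_emb_end": 581,
    "image_emb_len": 576,
    "double_return_token_id": 271,
}

# The <image> placeholder expands to image_emb_len hidden states located at
# [image_emb_start, image_emb_end) in the text encoder output.
assert template_meta["image_emb_end"] - template_meta["image_emb_start"] == template_meta["image_emb_len"]

# Assumption: 576 image tokens correspond to a square 24 x 24 patch grid.
side = int(template_meta["image_emb_len"] ** 0.5)
print(side, side * side)  # 24 576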
@@ -249,16 +272,84 @@ def _get_llama_prompt_embeds(
         text_input_ids = text_inputs.input_ids.to(device=device)
         prompt_attention_mask = text_inputs.attention_mask.to(device=device)
 
+        image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
+
         prompt_embeds = self.text_encoder(
             input_ids=text_input_ids,
             attention_mask=prompt_attention_mask,
+            pixel_values=image_embeds,
             output_hidden_states=True,
         ).hidden_states[-(num_hidden_layers_to_skip + 1)]
         prompt_embeds = prompt_embeds.to(dtype=dtype)
 
+        image_emb_len = prompt_template.get("image_emb_len", 576)
+        image_emb_start = prompt_template.get("image_emb_start", 5)
+        image_emb_end = prompt_template.get("image_emb_end", 581)
+        double_return_token_id = prompt_template.get("double_return_token_id", 271)
+
         if crop_start is not None and crop_start > 0:
-            prompt_embeds = prompt_embeds[:, crop_start:]
-            prompt_attention_mask = prompt_attention_mask[:, crop_start:]
+            text_crop_start = crop_start - 1 + image_emb_len
+            batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)
+
+            if last_double_return_token_indices.shape[0] == 3:
+                # in case the prompt is too long
+                last_double_return_token_indices = torch.cat(
+                    (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]]))
+                )
+                batch_indices = torch.cat((batch_indices, torch.tensor([0])))
+
+            last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[
+                :, -1
+            ]
+            batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1]
+            assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4
+            assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len
+            attention_mask_assistant_crop_start = last_double_return_token_indices - 4
+            attention_mask_assistant_crop_end = last_double_return_token_indices
+
+            prompt_embed_list = []
+            prompt_attention_mask_list = []
+            image_embed_list = []
+            image_attention_mask_list = []
+
+            for i in range(text_input_ids.shape[0]):
+                prompt_embed_list.append(
+                    torch.cat(
+                        [
+                            prompt_embeds[i, text_crop_start : assistant_crop_start[i].item()],
+                            prompt_embeds[i, assistant_crop_end[i].item() :],
+                        ]
+                    )
+                )
+                prompt_attention_mask_list.append(
+                    torch.cat(
+                        [
+                            prompt_attention_mask[i, crop_start : attention_mask_assistant_crop_start[i].item()],
+                            prompt_attention_mask[i, attention_mask_assistant_crop_end[i].item() :],
+                        ]
+                    )
+                )
+                image_embed_list.append(prompt_embeds[i, image_emb_start:image_emb_end])
+                image_attention_mask_list.append(
+                    torch.ones(image_embed_list[-1].shape[0]).to(prompt_embeds.device).to(prompt_attention_mask.dtype)
+                )
+
+            prompt_embed_list = torch.stack(prompt_embed_list)
+            prompt_attention_mask_list = torch.stack(prompt_attention_mask_list)
+            image_embed_list = torch.stack(image_embed_list)
+            image_attention_mask_list = torch.stack(image_attention_mask_list)
+
+            if image_embed_interleave < 6:
+                image_embed_list = image_embed_list[:, ::image_embed_interleave, :]
+                image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave]
+
+            assert (
+                prompt_embed_list.shape[0] == prompt_attention_mask_list.shape[0]
+                and image_embed_list.shape[0] == image_attention_mask_list.shape[0]
+            )
+
+            prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1)
+            prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1)
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         _, seq_len, _ = prompt_embeds.shape
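Note: the per-sample cropping above strips the system prompt and the trailing assistant header from the text span, extracts the 576 image hidden states, and optionally keeps only every image_embed_interleave-th image token before prepending them to the text tokens. A shape-only sketch of that interleaving step on dummy tensors (the 4096 hidden size is an assumption about the llama text encoder, not given in this diff):

import torch

batch, image_emb_len, hidden = 1, 576, 4096  # hidden size assumed
image_embed_list = torch.randn(batch, image_emb_len, hidden)
image_attention_mask_list = torch.ones(batch, image_emb_len)

image_embed_interleave = 2  # default added to _get_llama_prompt_embeds above
if image_embed_interleave < 6:
    # keep every 2nd image token, halving the image part of the sequence
    image_embed_list = image_embed_list[:, ::image_embed_interleave, :]
    image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave]

print(image_embed_list.shape)           # torch.Size([1, 288, 4096])
print(image_attention_mask_list.shape)  # torch.Size([1, 288])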
@@ -310,6 +401,7 @@ def _get_clip_prompt_embeds(
 
     def encode_prompt(
         self,
+        image: torch.Tensor,
         prompt: Union[str, List[str]],
         prompt_2: Union[str, List[str]] = None,
         prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
@@ -323,6 +415,7 @@ def encode_prompt(
     ):
         if prompt_embeds is None:
             prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
+                image,
                 prompt,
                 prompt_template,
                 num_videos_per_prompt,
@@ -393,6 +486,7 @@ def check_inputs(
 
     def prepare_latents(
         self,
+        image: torch.Tensor,
         batch_size: int,
         num_channels_latents: int = 32,
         height: int = 720,
@@ -403,24 +497,36 @@ def prepare_latents(
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if latents is not None:
-            return latents.to(device=device, dtype=dtype)
-
-        shape = (
-            batch_size,
-            num_channels_latents,
-            (num_frames - 1) // self.vae_scale_factor_temporal + 1,
-            int(height) // self.vae_scale_factor_spatial,
-            int(width) // self.vae_scale_factor_spatial,
-        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        return latents
+        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial
+        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+
+        image = image.unsqueeze(2)  # [B, C, 1, H, W]
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+            ]
+        else:
+            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
+
+        image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
+        image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device=device, dtype=dtype)
+
+        t = torch.tensor([0.999]).to(device=device)
+        latents = latents * t + image_latents * (1 - t)
+
+        return latents, image_latents
 
     def enable_vae_slicing(self):
         r"""
@@ -475,6 +581,7 @@ def interrupt(self):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
+        image: PipelineImageInput,
         prompt: Union[str, List[str]] = None,
         prompt_2: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
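Note: with image added as the first __call__ argument, a hypothetical end-to-end usage might look like the sketch below. The checkpoint id and file paths are placeholders, not taken from this commit, and argument defaults may differ.

import torch
from diffusers import HunyuanVideoImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

# Placeholder checkpoint id; substitute a real HunyuanVideo I2V checkpoint.
pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
    "path/to/hunyuanvideo-i2v-checkpoint", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("path/to/reference_frame.png")  # placeholder path
video = pipe(image=image, prompt="A cat walks on the grass").frames[0]
export_to_video(video, "output.mp4", fps=15)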
@@ -632,9 +739,30 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]
 
-        # 3. Encode input prompt
+        # 3. Prepare latent variables
+        vae_dtype = self.vae.dtype
+        image = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
+        num_channels_latents = (self.transformer.config.in_channels - 1) // 2
+        latents, image_latents = self.prepare_latents(
+            image,
+            batch_size * num_videos_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            num_frames,
+            torch.float32,
+            device,
+            generator,
+            latents,
+        )
+        image_latents[:, :, 1:] = 0
+        mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
+        mask[:, :, 1:] = 0
+
+        # 4. Encode input prompt
         transformer_dtype = self.transformer.dtype
         prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
+            image=image,
             prompt=prompt,
             prompt_2=prompt_2,
             prompt_template=prompt_template,
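Note: the block above zeroes the image latents and the mask beyond the first latent frame, so only frame 0 carries conditioning; later in the denoising loop these tensors are concatenated with the noisy latents along the channel axis, which is why num_channels_latents is derived as (in_channels - 1) // 2. A dummy-tensor sketch of that channel layout (16 latent channels is an assumption):

import torch

batch, c, frames, h, w = 1, 16, 9, 30, 53  # assumed sizes
latents = torch.randn(batch, c, frames, h, w)
image_latents = torch.randn(batch, c, frames, h, w)

image_latents[:, :, 1:] = 0  # conditioning only on the first latent frame
mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
mask[:, :, 1:] = 0

latent_model_input = torch.cat([latents, image_latents, mask], dim=1)
print(latent_model_input.shape[1])  # 33, i.e. 2 * c + 1 == transformer.config.in_channels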
@@ -651,6 +779,7 @@ def __call__(
 
         if do_true_cfg:
             negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
+                image=torch.full_like(image, fill_value=-1),
                 prompt=negative_prompt,
                 prompt_2=negative_prompt_2,
                 prompt_template=prompt_template,
@@ -669,23 +798,6 @@ def __call__(
         sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
 
-        # 5. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels
-        latents = self.prepare_latents(
-            batch_size * num_videos_per_prompt,
-            num_channels_latents,
-            height,
-            width,
-            num_frames,
-            torch.float32,
-            device,
-            generator,
-            latents,
-        )
-
-        # 6. Prepare guidance condition
-        guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
-
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -696,7 +808,7 @@ def __call__(
                     continue
 
                 self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
+                latent_model_input = torch.cat([latents, image_latents, mask], dim=1)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
@@ -706,7 +818,6 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
                     pooled_projections=pooled_prompt_embeds,
-                    guidance=guidance,
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -718,7 +829,6 @@ def __call__(
                         encoder_hidden_states=negative_prompt_embeds,
                         encoder_attention_mask=negative_prompt_attention_mask,
                         pooled_projections=negative_pooled_prompt_embeds,
-                        guidance=guidance,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
