Commit e978876

update
1 parent 1e6ada6 commit e978876

2 files changed: 9 additions (+), 22 deletions (−)


scripts/convert_hunyuan_video_to_diffusers.py

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ def get_args():
     text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
     tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
     scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
-    image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_2_path)
+    image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_path)

     pipe = HunyuanVideoImageToVideoPipeline(
         transformer=transformer,
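
Note on the fix above: the CLIPImageProcessor used to preprocess the conditioning image is now loaded from args.text_encoder_path (the primary llama text-encoder checkpoint) instead of the CLIP checkpoint at args.text_encoder_2_path. A minimal sketch of the corrected component loading, with placeholder paths standing in for the script's arguments:

    import torch
    from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

    clip_path = "path/to/text_encoder_2"   # placeholder for args.text_encoder_2_path
    llama_path = "path/to/text_encoder"    # placeholder for args.text_encoder_path

    text_encoder_2 = CLIPTextModel.from_pretrained(clip_path, torch_dtype=torch.float16)
    tokenizer_2 = CLIPTokenizer.from_pretrained(clip_path)
    # After this commit the image processor comes from the primary text-encoder path:
    image_processor = CLIPImageProcessor.from_pretrained(llama_path)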

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py

Lines changed: 8 additions & 21 deletions
@@ -16,6 +16,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
+import PIL.Image
 import torch
 from transformers import (
     CLIPImageProcessor,
@@ -26,7 +27,6 @@
 )

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import PipelineImageInput
 from ...loaders import HunyuanVideoLoraLoaderMixin
 from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -220,6 +220,7 @@ def __init__(
             image_processor=image_processor,
         )

+        self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986
         self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
         self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
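
The new self.vae_scaling_factor follows the same pattern as the other VAE-derived attributes: it reads vae.config.scaling_factor when the VAE is present and falls back to 0.476986 otherwise. The diff does not show where it is consumed; the sketch below only illustrates the usual diffusers convention of scaling VAE latents by this factor when encoding the conditioning image (the function name and tensor shapes are assumptions, not code from this commit):

    import torch

    def encode_image_to_latents(vae, image_tensor, scaling_factor=0.476986, generator=None):
        # Add a frame axis so the video VAE sees a single-frame clip: (B, C, 1, H, W).
        image_tensor = image_tensor.unsqueeze(2)
        latents = vae.encode(image_tensor).latent_dist.sample(generator=generator)
        # Conventionally, latents are multiplied by the scaling factor before reaching the transformer.
        return latents * scaling_factor
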
@@ -240,8 +241,6 @@ def _get_llama_prompt_embeds(
         dtype = dtype or self.text_encoder.dtype

         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)
-
         prompt = [prompt_template["template"].format(p) for p in prompt]

         crop_start = prompt_template.get("crop_start", None)
@@ -351,13 +350,6 @@ def _get_llama_prompt_embeds(
         prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1)
         prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1)

-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-        prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
-        prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
-
         return prompt_embeds, prompt_attention_mask

     def _get_clip_prompt_embeds(
@@ -372,7 +364,6 @@ def _get_clip_prompt_embeds(
         dtype = dtype or self.text_encoder_2.dtype

         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)

         text_inputs = self.tokenizer_2(
             prompt,
@@ -392,11 +383,6 @@
         )

         prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
-
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
-        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
-
         return prompt_embeds

     def encode_prompt(
@@ -581,7 +567,7 @@ def interrupt(self):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        image: PipelineImageInput,
+        image: PIL.Image.Image,
         prompt: Union[str, List[str]] = None,
         prompt_2: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
@@ -741,10 +727,10 @@ def __call__(

         # 3. Prepare latent variables
         vae_dtype = self.vae.dtype
-        image = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
+        image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
         num_channels_latents = (self.transformer.config.in_channels - 1) // 2
         latents, image_latents = self.prepare_latents(
-            image,
+            image_tensor,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
             height,
@@ -778,8 +764,9 @@ def __call__(
         pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)

         if do_true_cfg:
+            black_image = PIL.Image.new("RGB", (width, height), 0)
             negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt(
-                image=torch.full_like(image, fill_value=-1),
+                image=black_image,
                 prompt=negative_prompt,
                 prompt_2=negative_prompt_2,
                 prompt_template=prompt_template,
@@ -808,7 +795,7 @@ def __call__(
                     continue

                 self._current_timestep = t
-                latent_model_input = torch.cat([latents, image_latents, mask], dim=1)
+                latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
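
Two user-facing consequences of the pipeline changes above: __call__ now expects a plain PIL.Image.Image rather than the broader PipelineImageInput, and the concatenated model input is cast explicitly to the transformer dtype, which avoids dtype mismatches when the VAE and transformer run at different precisions. A hedged usage sketch; the checkpoint id and parameter values below are placeholders, not taken from this commit:

    import PIL.Image
    import torch
    from diffusers import HunyuanVideoImageToVideoPipeline

    # Placeholder repo id; substitute a real HunyuanVideo image-to-video checkpoint.
    pipe = HunyuanVideoImageToVideoPipeline.from_pretrained(
        "your-org/HunyuanVideo-I2V-diffusers", torch_dtype=torch.bfloat16
    ).to("cuda")

    # The conditioning image must now be a PIL image, not a tensor.
    image = PIL.Image.open("input.png").convert("RGB")
    output = pipe(image=image, prompt="a cat walking through a garden", num_frames=49)
    frames = output.frames[0]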
