
Commit af24bea

update
1 parent 5da0839 commit af24bea

3 files changed: +62 −21 lines changed

scripts/convert_hunyuan_video_to_diffusers.py

Lines changed: 22 additions & 1 deletion
@@ -160,8 +160,9 @@ def remap_single_transformer_blocks_(key, state_dict):
         "pooled_projection_dim": 768,
         "rope_theta": 256.0,
         "rope_axes_dim": (16, 56, 56),
+        "image_condition_type": None,
     },
-    "HYVideo-T/2-I2V": {
+    "HYVideo-T/2-I2V-33ch": {
         "in_channels": 16 * 2 + 1,
         "out_channels": 16,
         "num_attention_heads": 24,
@@ -178,6 +179,26 @@ def remap_single_transformer_blocks_(key, state_dict):
         "pooled_projection_dim": 768,
         "rope_theta": 256.0,
         "rope_axes_dim": (16, 56, 56),
+        "image_condition_type": "latent_concat",
+    },
+    "HYVideo-T/2-I2V-16ch": {
+        "in_channels": 16,
+        "out_channels": 16,
+        "num_attention_heads": 24,
+        "attention_head_dim": 128,
+        "num_layers": 20,
+        "num_single_layers": 40,
+        "num_refiner_layers": 2,
+        "mlp_ratio": 4.0,
+        "patch_size": 2,
+        "patch_size_t": 1,
+        "qk_norm": "rms_norm",
+        "guidance_embeds": True,
+        "text_embed_dim": 4096,
+        "pooled_projection_dim": 768,
+        "rope_theta": 256.0,
+        "rope_axes_dim": (16, 56, 56),
+        "image_condition_type": "token_replace",
     },
 }

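The two I2V entries map one-to-one onto the two conditioning modes this commit introduces, and their in_channels values follow directly from them. A minimal sketch of that arithmetic is below; the splat into the constructor shown in the final comment is an assumption about the surrounding conversion code, included only for orientation (it is commented out because it would build the full multi-billion-parameter module).

latent_channels = 16

# "HYVideo-T/2-I2V-33ch" uses image_condition_type="latent_concat": video latents,
# image latents, and a binary mask are stacked along the channel axis, hence 16 * 2 + 1.
in_channels_latent_concat = latent_channels * 2 + 1  # 33

# "HYVideo-T/2-I2V-16ch" uses image_condition_type="token_replace": the clean image
# latent replaces the first temporal frame instead of adding channels, so 16 suffices.
in_channels_token_replace = latent_channels  # 16

# Each config dict mirrors the HunyuanVideoTransformer3DModel constructor arguments,
# so it can be passed straight through, e.g.:
# transformer = HunyuanVideoTransformer3DModel(**TRANSFORMER_CONFIGS["HYVideo-T/2-I2V-16ch"])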
src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 23 additions & 16 deletions
@@ -228,8 +228,8 @@ def forward(
         )
 
 
-class HunyuanVideoTimestepTextProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim: int, pooled_projection_dim: int, image_condition_type: Optional[str] = None):
+class HunyuanVideoConditionEmbedding(nn.Module):
+    def __init__(self, embedding_dim: int, pooled_projection_dim: int, guidance_embeds: bool, image_condition_type: Optional[str] = None):
         super().__init__()
 
         self.image_condition_type = image_condition_type
@@ -238,7 +238,11 @@ def __init__(self, embedding_dim: int, pooled_projection_dim: int, image_conditi
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
         self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
 
-    def forward(self, timestep: torch.Tensor, pooled_projection: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        self.guidance_embedder = None
+        if guidance_embeds:
+            self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+    def forward(self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
         pooled_projections = self.text_embedder(pooled_projection)
@@ -248,8 +252,13 @@ def forward(self, timestep: torch.Tensor, pooled_projection: torch.Tensor) -> Tu
         if self.image_condition_type == "token_replace":
             token_replace_timestep = torch.zeros_like(timestep)
             token_replace_proj = self.time_proj(token_replace_timestep)
-            token_replace_emb = self.timestep_embedder(token_replace_proj)
-            token_replace_emb = token_replace_emb + conditioning
+            token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
+            token_replace_emb = token_replace_emb + pooled_projections
+
+        if self.guidance_embedder is not None:
+            guidance_proj = self.time_proj(guidance)
+            guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
+            conditioning = conditioning + guidance_emb
 
         return conditioning, token_replace_emb
 
@@ -665,7 +674,7 @@ def forward(
 
         hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
         hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
-        hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
+        norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
 
         # 4. Feed-forward
@@ -717,6 +726,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
             The value of theta to use in the RoPE layer.
         rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
             The dimensions of the axes to use in the RoPE layer.
+        image_condition_type (`str`, *optional*, defaults to `None`):
+            The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
+            image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
+            tokens in the latent stream and apply conditioning.
     """
 
     _supports_gradient_checkpointing = True
@@ -761,12 +774,9 @@ def __init__(
             text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
         )
 
-        if guidance_embeds:
-            self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
-        else:
-            self.time_text_embed = HunyuanVideoTimestepTextProjEmbeddings(
-                inner_dim, pooled_projection_dim, image_condition_type
-            )
+        self.time_text_embed = HunyuanVideoConditionEmbedding(
+            inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
+        )
 
         # 2. RoPE
         self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
@@ -904,10 +914,7 @@ def forward(
         image_rotary_emb = self.rope(hidden_states)
 
         # 2. Conditional embeddings
-        if self.config.guidance_embeds:
-            temb = self.time_text_embed(timestep, guidance, pooled_projections)
-        else:
-            temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections)
+        temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
 
         hidden_states = self.x_embedder(hidden_states)
         encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)

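To see how the pieces of the new HunyuanVideoConditionEmbedding fit together, here is a simplified, self-contained sketch of the same conditioning path. It assumes the standard Timesteps / TimestepEmbedding / PixArtAlphaTextProjection building blocks from diffusers.models.embeddings with the usual sinusoidal-projection settings; it illustrates the merged logic and is not the library code verbatim.

from typing import Optional, Tuple

import torch
import torch.nn as nn
from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps


class ConditionEmbeddingSketch(nn.Module):
    # Simplified mirror of the merged timestep / pooled-text / guidance conditioning path.
    def __init__(self, embedding_dim, pooled_projection_dim, guidance_embeds, image_condition_type=None):
        super().__init__()
        self.image_condition_type = image_condition_type
        # Sinusoidal projection hyperparameters assumed to match the diffusers defaults.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
        self.guidance_embedder = TimestepEmbedding(256, embedding_dim) if guidance_embeds else None

    def forward(self, timestep, pooled_projection, guidance=None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        dtype = pooled_projection.dtype
        pooled = self.text_embedder(pooled_projection)
        conditioning = self.timestep_embedder(self.time_proj(timestep).to(dtype)) + pooled

        token_replace_emb = None
        if self.image_condition_type == "token_replace":
            # First-frame tokens are embedded as if their timestep were 0: they stay clean.
            zeros = torch.zeros_like(timestep)
            token_replace_emb = self.timestep_embedder(self.time_proj(zeros).to(dtype)) + pooled

        if self.guidance_embedder is not None:
            # Distilled-guidance embedding; guidance is a per-sample scalar (scale * 1000).
            conditioning = conditioning + self.guidance_embedder(self.time_proj(guidance).to(dtype))

        return conditioning, token_replace_emb

Because forward always returns a (conditioning, token_replace_emb) pair and folds the optional guidance embedding in itself, the transformer no longer needs to branch on guidance_embeds, which is what allows the unconditional self.time_text_embed(timestep, pooled_projections, guidance) call in the forward hunk above.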
src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py

Lines changed: 17 additions & 4 deletions
@@ -54,6 +54,7 @@
         >>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel
         >>> from diffusers.utils import load_image, export_to_video
 
+        >>> # Available checkpoints: hunyuanvideo-community/HunyuanVideo-I2V, hunyuanvideo-community/HunyuanVideo-I2V-33ch
         >>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V"
         >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
         ...     model_id, subfolder="transformer", torch_dtype=torch.bfloat16
@@ -69,7 +70,12 @@
         ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png"
         ... )
 
-        >>> output = pipe(image=image, prompt=prompt).frames[0]
+        >>> # If using hunyuanvideo-community/HunyuanVideo-I2V
+        >>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0]
+
+        >>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch
+        >>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0]
+
         >>> export_to_video(output, "output.mp4", fps=15)
         ```
     """
@@ -804,10 +810,15 @@ def __call__(
             negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
             negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
 
-        # 4. Prepare timesteps
+        # 5. Prepare timesteps
        sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
 
+        # 6. Prepare guidance condition
+        guidance = None
+        if self.transformer.config.guidance_embeds:
+            guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -824,14 +835,15 @@ def __call__(
                 if image_condition_type == "latent_concat":
                     latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
                 elif image_condition_type == "token_replace":
-                    latent_model_input = latents.to(transformer_dtype)
+                    latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
 
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
                     pooled_projections=pooled_prompt_embeds,
+                    guidance=guidance,
                     attention_kwargs=attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -843,6 +855,7 @@ def __call__(
                         encoder_hidden_states=negative_prompt_embeds,
                         encoder_attention_mask=negative_prompt_attention_mask,
                         pooled_projections=negative_pooled_prompt_embeds,
+                        guidance=guidance,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
@@ -855,7 +868,7 @@ def __call__(
                     latents = latents = self.scheduler.step(
                         noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
                     )[0]
-                    latents = torch.cat([image_latents, latents])
+                    latents = torch.cat([image_latents, latents], dim=2)
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}

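For orientation, a minimal sketch of one denoising step under token_replace conditioning as the loop above now performs it. The helper name and the argument bundling are invented for the example, but the tensor layout matches the pipeline: latents are (batch, channels, frames, height, width), so dim=2 is the frame axis, and the previously missing dim= in the final concatenation (which defaulted to the batch axis) is exactly what this commit corrects.

import torch

def token_replace_step(transformer, scheduler, latents, image_latents, t, timestep, cond_kwargs):
    # latents: (B, C, F, H, W); image_latents: clean first-frame latent, (B, C, 1, H, W).
    # Re-insert the clean image latent as the first temporal frame before the forward pass.
    latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2)

    noise_pred = transformer(
        hidden_states=latent_model_input, timestep=timestep, return_dict=False, **cond_kwargs
    )[0]

    # Only frames 1..F-1 are denoised; the image frame is re-attached unchanged afterwards.
    denoised_rest = scheduler.step(noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False)[0]
    return torch.cat([image_latents, denoised_rest], dim=2)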