
Commit 671149e

[HunyuanVideo1.5] support step-distilled (#12802)
* support step-distilled
* style
1 parent f67639b commit 671149e

3 files changed: +55 / -5 lines


scripts/convert_hunyuan_video1_5_to_diffusers.py

Lines changed: 27 additions & 2 deletions
@@ -69,6 +69,11 @@
         "target_size": 960,
         "task_type": "i2v",
     },
+    "480p_i2v_step_distilled": {
+        "target_size": 640,
+        "task_type": "i2v",
+        "use_meanflow": True,
+    },
 }
 
 SCHEDULER_CONFIGS = {
@@ -93,6 +98,9 @@
     "720p_i2v_distilled": {
         "shift": 7.0,
     },
+    "480p_i2v_step_distilled": {
+        "shift": 7.0,
+    },
 }
 
 GUIDANCE_CONFIGS = {
@@ -117,6 +125,9 @@
     "720p_i2v_distilled": {
         "guidance_scale": 1.0,
     },
+    "480p_i2v_step_distilled": {
+        "guidance_scale": 1.0,
+    },
 }
 
 
@@ -126,7 +137,7 @@ def swap_scale_shift(weight):
     return new_weight
 
 
-def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
+def convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=None):
     """
     Convert HunyuanVideo 1.5 original checkpoint to Diffusers format.
     """
@@ -142,6 +153,20 @@ def convert_hyvideo15_transformer_to_diffusers(original_state_dict):
     )
     converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("time_in.mlp.2.bias")
 
+    if config.use_meanflow:
+        converted_state_dict["time_embed.timestep_embedder_r.linear_1.weight"] = original_state_dict.pop(
+            "time_r_in.mlp.0.weight"
+        )
+        converted_state_dict["time_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop(
+            "time_r_in.mlp.0.bias"
+        )
+        converted_state_dict["time_embed.timestep_embedder_r.linear_2.weight"] = original_state_dict.pop(
+            "time_r_in.mlp.2.weight"
+        )
+        converted_state_dict["time_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop(
+            "time_r_in.mlp.2.bias"
+        )
+
     # 2. context_embedder.time_text_embed.timestep_embedder <- txt_in.t_embedder
     converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
         original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
@@ -627,7 +652,7 @@ def convert_transformer(args):
     config = TRANSFORMER_CONFIGS[args.transformer_type]
     with init_empty_weights():
         transformer = HunyuanVideo15Transformer3DModel(**config)
-    state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict)
+    state_dict = convert_hyvideo15_transformer_to_diffusers(original_state_dict, config=transformer.config)
     transformer.load_state_dict(state_dict, strict=True, assign=True)
 
     return transformer
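
Aside (not part of the diff): the converter is now handed the model config so it knows whether the checkpoint ships the extra meanflow time embedder. The sketch below reproduces the new time_r_in.* -> time_embed.timestep_embedder_r.* remapping on its own with dummy tensors; the 3072/256 shapes are illustrative assumptions, not values taken from this PR.

import torch

# Dummy stand-ins for the original checkpoint's meanflow keys (shapes assumed).
original_state_dict = {
    "time_r_in.mlp.0.weight": torch.zeros(3072, 256),
    "time_r_in.mlp.0.bias": torch.zeros(3072),
    "time_r_in.mlp.2.weight": torch.zeros(3072, 3072),
    "time_r_in.mlp.2.bias": torch.zeros(3072),
}

converted_state_dict = {}
# mlp.0 -> linear_1 and mlp.2 -> linear_2, mirroring the existing time_in.* mapping.
for src_idx, dst_idx in ((0, 1), (2, 2)):
    for param in ("weight", "bias"):
        converted_state_dict[f"time_embed.timestep_embedder_r.linear_{dst_idx}.{param}"] = original_state_dict.pop(
            f"time_r_in.mlp.{src_idx}.{param}"
        )

assert not original_state_dict  # every meanflow key was consumed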

src/diffusers/models/transformers/transformer_hunyuan_video15.py

Lines changed: 18 additions & 3 deletions
@@ -184,19 +184,32 @@ class HunyuanVideo15TimeEmbedding(nn.Module):
             The dimension of the output embedding.
     """
 
-    def __init__(self, embedding_dim: int):
+    def __init__(self, embedding_dim: int, use_meanflow: bool = False):
         super().__init__()
 
         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
 
+        self.use_meanflow = use_meanflow
+        self.time_proj_r = None
+        self.timestep_embedder_r = None
+        if use_meanflow:
+            self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+            self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
     def forward(
         self,
         timestep: torch.Tensor,
+        timestep_r: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         timesteps_proj = self.time_proj(timestep)
         timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
 
+        if timestep_r is not None:
+            timesteps_proj_r = self.time_proj_r(timestep_r)
+            timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
+            timesteps_emb = timesteps_emb + timesteps_emb_r
+
         return timesteps_emb
 
 
@@ -567,6 +580,7 @@ def __init__(
         # YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205
         target_size: int = 640,  # did not name sample_size since it is in pixel spaces
         task_type: str = "i2v",
+        use_meanflow: bool = False,
     ) -> None:
         super().__init__()
 
@@ -582,7 +596,7 @@ def __init__(
         )
         self.context_embedder_2 = HunyuanVideo15ByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
 
-        self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim)
+        self.time_embed = HunyuanVideo15TimeEmbedding(inner_dim, use_meanflow=use_meanflow)
 
         self.cond_type_embed = nn.Embedding(3, inner_dim)
 
@@ -612,6 +626,7 @@ def forward(
         timestep: torch.LongTensor,
         encoder_hidden_states: torch.Tensor,
         encoder_attention_mask: torch.Tensor,
+        timestep_r: Optional[torch.LongTensor] = None,
         encoder_hidden_states_2: Optional[torch.Tensor] = None,
         encoder_attention_mask_2: Optional[torch.Tensor] = None,
         image_embeds: Optional[torch.Tensor] = None,
@@ -643,7 +658,7 @@ def forward(
         image_rotary_emb = self.rope(hidden_states)
 
         # 2. Conditional embeddings
-        temb = self.time_embed(timestep)
+        temb = self.time_embed(timestep, timestep_r=timestep_r)
 
         hidden_states = self.x_embedder(hidden_states)
 
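
Aside (not part of the diff): when use_meanflow is enabled, the time embedding simply runs a second sinusoidal projection plus MLP over the reference timestep r and adds it to the usual timestep embedding. A minimal sketch of that path using diffusers' Timesteps/TimestepEmbedding helpers; embedding_dim=3072 and the timestep values are illustrative assumptions.

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

embedding_dim = 3072  # assumed inner_dim, for illustration only

time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

timestep = torch.tensor([999.0])    # current denoising timestep t
timestep_r = torch.tensor([749.0])  # reference timestep r (e.g. the next step)

temb = timestep_embedder(time_proj(timestep).to(dtype=timestep.dtype))
temb_r = timestep_embedder_r(time_proj_r(timestep_r).to(dtype=timestep.dtype))
temb = temb + temb_r  # matches the `timestep_r is not None` branch in forward()
print(temb.shape)  # torch.Size([1, 3072])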

src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5_image2video.py

Lines changed: 10 additions & 0 deletions
@@ -852,6 +852,15 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
 
+                if self.transformer.config.use_meanflow:
+                    if i == len(timesteps) - 1:
+                        timestep_r = torch.tensor([0.0], device=device)
+                    else:
+                        timestep_r = timesteps[i + 1]
+                    timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
+                else:
+                    timestep_r = None
+
                 # Step 1: Collect model inputs needed for the guidance method
                 # conditional inputs should always be first element in the tuple
                 guider_inputs = {
@@ -893,6 +902,7 @@ def __call__(
                     hidden_states=latent_model_input,
                     image_embeds=image_embeds,
                     timestep=timestep,
+                    timestep_r=timestep_r,
                     attention_kwargs=self.attention_kwargs,
                     return_dict=False,
                     **cond_kwargs,
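
Aside (not part of the diff): for a meanflow (step-distilled) transformer, each denoising step also receives the next timestep as a reference time, falling back to 0.0 on the final step. The standalone sketch below shows that schedule on its own; the 4-step timestep values and batch size are illustrative, not taken from the pipeline.

import torch

device = "cpu"
timesteps = torch.tensor([1000.0, 750.0, 500.0, 250.0], device=device)  # assumed schedule
batch_size = 2

for i, t in enumerate(timesteps):
    timestep = t.expand(batch_size)
    if i == len(timesteps) - 1:
        timestep_r = torch.tensor([0.0], device=device)  # last step references t=0
    else:
        timestep_r = timesteps[i + 1]                    # otherwise, the next timestep
    timestep_r = timestep_r.expand(batch_size)
    print(i, timestep.tolist(), timestep_r.tolist())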
