
Commit b6c7ae0

convert

1 parent 16778b1 commit b6c7ae0

File tree: 2 files changed, +105 −122 lines changed

scripts/convert_hunyuan_video_to_diffusers.py

Lines changed: 30 additions & 20 deletions
@@ -5,7 +5,7 @@
 from accelerate import init_empty_weights
 from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer
 
-from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel, HunyuanVideoPipeline
+from diffusers import AutoencoderKLHunyuanVideo, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
 
 
 def remap_norm_scale_shift_(key, state_dict):
@@ -15,6 +15,23 @@ def remap_norm_scale_shift_(key, state_dict):
     state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
 
 
+def remap_token_refiner_blocks_(key, state_dict):
+    def rename_key(key):
+        new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+        new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+        return new_key
+
+    if "self_attn_qkv" in key:
+        weight = state_dict.pop(key)
+        to_q, to_k, to_v = weight.chunk(3, dim=0)
+        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+        state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+    else:
+        state_dict[rename_key(key)] = state_dict.pop(key)
+
+
 def remap_img_attn_qkv_(key, state_dict):
     weight = state_dict.pop(key)
     to_q, to_k, to_v = weight.chunk(3, dim=0)
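For reference, the fused-QKV split these helpers perform is just a chunk along the output dimension; a minimal toy sketch (the sizes here are illustrative, not the model's real 3072 hidden size):

    import torch

    hidden_size = 8
    # The original checkpoint stores q, k, v stacked into one (3 * hidden, hidden) weight.
    fused_qkv = torch.randn(3 * hidden_size, hidden_size)

    # chunk(3, dim=0) recovers the three projection weights in q, k, v order.
    to_q, to_k, to_v = fused_qkv.chunk(3, dim=0)
    assert to_q.shape == (hidden_size, hidden_size)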
@@ -31,14 +48,6 @@ def remap_txt_attn_qkv_(key, state_dict):
     state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
 
 
-def remap_self_attn_qkv_(key, state_dict):
-    weight = state_dict.pop(key)
-    to_q, to_k, to_v = weight.chunk(3, dim=0)
-    state_dict[key.replace("self_attn_qkv", "attn.to_q")] = to_q
-    state_dict[key.replace("self_attn_qkv", "attn.to_k")] = to_k
-    state_dict[key.replace("self_attn_qkv", "attn.to_v")] = to_v
-
-
 def remap_single_transformer_blocks_(key, state_dict):
     hidden_size = 3072
@@ -71,16 +80,16 @@ def remap_single_transformer_blocks_(key, state_dict):
 
 
 TRANSFORMER_KEYS_RENAME_DICT = {
-    # "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
-    # "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
-    # "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
-    # "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
-    # "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
-    # "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
-    "txt_in.t_embedder": "txt_in.time_embed",
-    "txt_in.c_embedder": "txt_in.context_embed",
+    "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+    "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+    "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+    "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+    "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+    "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+    "txt_in.t_embedder.mlp.0": "txt_in.time_text_embed.timestep_embedder.linear_1",
+    "txt_in.t_embedder.mlp.2": "txt_in.time_text_embed.timestep_embedder.linear_2",
+    "txt_in.c_embedder": "txt_in.time_text_embed.text_embedder",
     "double_blocks": "transformer_blocks",
-    "individual_token_refiner.blocks": "token_refiner.refiner_blocks",
     "img_attn_q_norm": "attn.norm_q",
     "img_attn_k_norm": "attn.norm_k",
     "img_attn_proj": "attn.to_out.0",
@@ -102,14 +111,15 @@ def remap_single_transformer_blocks_(key, state_dict):
     "final_layer.linear": "proj_out",
     "fc1": "net.0.proj",
     "fc2": "net.2",
+    "input_embedder": "proj_in",
 }
 
 TRANSFORMER_SPECIAL_KEYS_REMAP = {
-    "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
     "img_attn_qkv": remap_img_attn_qkv_,
     "txt_attn_qkv": remap_txt_attn_qkv_,
-    "self_attn_qkv": remap_self_attn_qkv_,
     "single_blocks": remap_single_transformer_blocks_,
+    "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+    "individual_token_refiner.blocks": remap_token_refiner_blocks_,
 }
 
 VAE_KEYS_RENAME_DICT = {}
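For context, rename dicts like these are consumed by a substring-driven loop elsewhere in the script; a hedged sketch of that pattern (the function name and loop shape here are illustrative, not the script's exact code):

    def convert_state_dict(state_dict):
        # Plain renames: substitute every matching substring in each key.
        for key in list(state_dict.keys()):
            new_key = key
            for old, new in TRANSFORMER_KEYS_RENAME_DICT.items():
                new_key = new_key.replace(old, new)
            if new_key != key:
                state_dict[new_key] = state_dict.pop(key)

        # Structural rewrites: hand matching keys to their in-place handler
        # (QKV splits, token-refiner renames, norm scale/shift reordering).
        for key in list(state_dict.keys()):
            for pattern, handler in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
                if pattern in key:
                    handler(key, state_dict)
                    break  # each key is rewritten by at most one handler
        return state_dict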

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 75 additions & 102 deletions
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import math
-from functools import partial
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -24,7 +22,11 @@
 from ...utils import is_torch_version
 from ..attention import FeedForward
 from ..attention_processor import Attention, AttentionProcessor
-from ..embeddings import get_1d_rotary_pos_embed, get_timestep_embedding
+from ..embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+    get_1d_rotary_pos_embed,
+)
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -123,19 +125,6 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
-class MLPEmbedder(nn.Module):
-    """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
-
-    def __init__(self, in_dim: int, hidden_dim: int):
-        super().__init__()
-        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
-        self.silu = nn.SiLU()
-        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.out_layer(self.silu(self.in_layer(x)))
-
-
 class PatchEmbed(nn.Module):
     def __init__(
         self,
@@ -154,49 +143,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
-class TextProjection(nn.Module):
-    def __init__(self, in_channels, hidden_size, act_layer):
+class HunyuanVideoAdaNorm(nn.Module):
+    def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
         super().__init__()
-        self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True)
-        self.act_1 = act_layer()
-        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
-
-    def forward(self, caption):
-        hidden_states = self.linear_1(caption)
-        hidden_states = self.act_1(hidden_states)
-        hidden_states = self.linear_2(hidden_states)
-        return hidden_states
-
 
-class TimestepEmbedder(nn.Module):
-    """
-    Embeds scalar timesteps into vector representations.
-    """
-
-    def __init__(
-        self,
-        hidden_size,
-        act_layer,
-        frequency_embedding_size=256,
-        max_period=10000,
-        out_size=None,
-    ):
-        super().__init__()
-        self.frequency_embedding_size = frequency_embedding_size
-        self.max_period = max_period
-        if out_size is None:
-            out_size = hidden_size
-
-        self.mlp = nn.Sequential(
-            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
-            act_layer(),
-            nn.Linear(hidden_size, out_size, bias=True),
-        )
+        out_features = out_features or 2 * in_features
+        self.linear = nn.Linear(in_features, out_features)
+        self.nonlinearity = nn.SiLU()
 
-    def forward(self, t):
-        t_freq = get_timestep_embedding(t, self.frequency_embedding_size, flip_sin_to_cos=True, max_period=self.max_period, downscale_freq_shift=0).type(self.mlp[0].weight.dtype)
-        t_emb = self.mlp(t_freq)
-        return t_emb
+    def forward(
+        self, temb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        temb = self.linear(self.nonlinearity(temb))
+        gate_msa, gate_mlp = temb.chunk(2, dim=1)
+        gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+        return gate_msa, gate_mlp
 
 
 class IndividualTokenRefinerBlock(nn.Module):
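A quick shape check of the new gating module as defined above (illustrative sizes; the gates broadcast over the sequence dimension):

    import torch

    batch_size, hidden_size = 2, 32
    norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)

    temb = torch.randn(batch_size, hidden_size)
    gate_msa, gate_mlp = norm_out(temb)

    # Each gate comes back as (batch, 1, hidden) so it can scale
    # (batch, seq, hidden) activations in the refiner block below.
    assert gate_msa.shape == (batch_size, 1, hidden_size)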
@@ -224,29 +185,27 @@ def __init__(
         self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
         self.mlp = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="silu", dropout=mlp_drop_rate)
 
-        self.adaLN_modulation = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
-        )
+        self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         temb: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        gate_msa, gate_mlp = self.adaLN_modulation(temb).chunk(2, dim=1)
-
         norm_hidden_states = self.norm1(hidden_states)
 
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=None,
             attention_mask=attention_mask,
         )
-        hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
 
-        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) * gate_mlp.unsqueeze(1)
+        gate_msa, gate_mlp = self.norm_out(temb)
+        hidden_states = hidden_states + attn_output * gate_msa
+
+        ff_output = self.mlp(self.norm2(hidden_states))
+        hidden_states = hidden_states + ff_output * gate_mlp
 
         return hidden_states

@@ -313,10 +272,10 @@ def __init__(
 
         hidden_size = num_attention_heads * attention_head_dim
 
-        self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True)
-        self.time_embed = TimestepEmbedder(hidden_size, nn.SiLU)
-        self.context_embed = TextProjection(in_channels, hidden_size, nn.SiLU)
-
+        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+            embedding_dim=hidden_size, pooled_projection_dim=in_channels
+        )
+        self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
         self.token_refiner = IndividualTokenRefiner(
             num_attention_heads=num_attention_heads,
             attention_head_dim=attention_head_dim,
@@ -332,21 +291,17 @@ def forward(
         timestep: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
-        original_dtype = hidden_states.dtype
-        temb = self.time_embed(timestep)
-
         if attention_mask is None:
             pooled_projections = hidden_states.mean(dim=1)
         else:
+            original_dtype = hidden_states.dtype
             mask_float = attention_mask.float().unsqueeze(-1)
             pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
             pooled_projections = pooled_projections.to(original_dtype)
 
-        pooled_projections = self.context_embed(pooled_projections)
-        emb = temb + pooled_projections
-
-        hidden_states = self.input_embedder(hidden_states)
-        hidden_states = self.token_refiner(hidden_states, emb, attention_mask)
+        temb = self.time_text_embed(timestep, pooled_projections)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
 
         return hidden_states
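The masked pooling above reduces variable-length text features to one vector per sample; a toy check of the arithmetic:

    import torch

    hidden_states = torch.randn(2, 5, 8)               # (batch, seq, dim)
    attention_mask = torch.tensor([[1, 1, 1, 0, 0],    # 1 = real token, 0 = padding
                                   [1, 1, 1, 1, 1]])

    mask_float = attention_mask.float().unsqueeze(-1)  # (batch, seq, 1)
    pooled = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
    assert pooled.shape == (2, 8)                      # padded positions excluded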

@@ -561,14 +516,7 @@ def __init__(
             text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
         )
 
-        # time modulation
-        self.time_in = TimestepEmbedder(inner_dim, nn.SiLU)
-
-        # text modulation
-        self.vector_in = MLPEmbedder(text_embed_dim_2, inner_dim)
-
-        # guidance modulation
-        self.guidance_in = TimestepEmbedder(inner_dim, nn.SiLU)
+        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, text_embed_dim_2)
 
         # 3. RoPE
         self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_dim_list, rope_theta)
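Functionally, the single combined module computes the same sum the three removed modules produced; a conceptual sketch (the embedder names below are illustrative, not the diffusers internals):

    # temb = timestep_embedding(timestep)    # was self.time_in
    #      + guidance_embedding(guidance)    # was self.guidance_in
    #      + text_embedding(pooled_text)     # was self.vector_in
    #
    # This is why the conversion script above maps the old time_in.*,
    # guidance_in.*, and vector_in.* keys onto time_text_embed.* submodules.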
@@ -679,30 +627,55 @@ def forward(
 
         image_rotary_emb = self.rope(hidden_states)
 
-        temb = self.time_in(timestep)
-        temb = temb + self.vector_in(encoder_hidden_states_2)
-        temb = temb + self.guidance_in(guidance)
+        temb = self.time_text_embed(timestep, guidance, encoder_hidden_states_2)
 
         # Embed image and text.
         hidden_states = self.img_in(hidden_states)
         encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask)
 
-        use_reentrant = is_torch_version(">=", "1.11.0")
-        block_forward = (
-            partial(torch.utils.checkpoint.checkpoint, use_reentrant=use_reentrant)
-            if torch.is_grad_enabled() and self.gradient_checkpointing
-            else lambda x: x
-        )
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
 
-        for _, block in enumerate(self.transformer_blocks):
-            hidden_states, encoder_hidden_states = block_forward(block)(
-                hidden_states, encoder_hidden_states, temb, image_rotary_emb
-            )
+            def create_custom_forward(module, return_dict=None):
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
 
-        for block in self.single_transformer_blocks:
-            hidden_states, encoder_hidden_states = block_forward(block)(
-                hidden_states, encoder_hidden_states, temb, image_rotary_emb
-            )
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                )
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
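The wrapper-plus-checkpoint pattern above trades compute for memory: block activations are recomputed during backward instead of stored. A minimal standalone demonstration of the same pattern (toy module, not the HunyuanVideo blocks):

    import torch
    import torch.nn as nn

    block = nn.Linear(8, 8)
    x = torch.randn(2, 8, requires_grad=True)

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    # Inputs are passed positionally; intermediate activations are freed after
    # the forward pass and recomputed when backward reaches this segment.
    out = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
    out.sum().backward()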
