
Commit 1e80f7c ("make style")
Parent: cb4fc37

3 files changed: +50, -27 lines


scripts/convert_hunyuan_video_to_diffusers.py

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ def remap_single_transformer_blocks_(key, state_dict):
         state_dict[f"{new_key}.attn.to_k.weight"] = k
         state_dict[f"{new_key}.attn.to_v.weight"] = v
         state_dict[f"{new_key}.proj_mlp.weight"] = mlp
-
+
     elif "linear1.bias" in key:
         linear1_bias = state_dict.pop(key)
         split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
@@ -36,7 +36,7 @@ def remap_single_transformer_blocks_(key, state_dict):
         state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
         state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
         state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
-
+
     else:
         new_key = key.replace("single_blocks", "single_transformer_blocks")
         new_key = new_key.replace("linear2", "proj_out")
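The `split_size` context line above is the heart of this remapping: HunyuanVideo's original checkpoints fuse the q, k, v projections and the MLP input projection into a single `linear1` tensor, and the converter slices it apart. A minimal standalone sketch of that split (the sizes below are assumptions for illustration, not values read from a checkpoint):

import torch

# Hypothetical shapes for illustration only.
hidden_size = 3072
mlp_dim = 4 * hidden_size
linear1_weight = torch.randn(3 * hidden_size + mlp_dim, hidden_size)

# Same arithmetic as the converter: three hidden_size-sized chunks for q/k/v,
# and whatever remains is the MLP input projection.
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
assert mlp.shape == (mlp_dim, hidden_size)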

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 12 additions & 2 deletions
@@ -1043,8 +1043,18 @@ def __init__(
         in_channels: int = 3,
         out_channels: int = 3,
         latent_channels: int = 16,
-        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D", "DownEncoderBlockCausal3D",),
-        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D", "UpDecoderBlockCausal3D",),
+        down_block_types: Tuple[str, ...] = (
+            "DownEncoderBlockCausal3D",
+            "DownEncoderBlockCausal3D",
+            "DownEncoderBlockCausal3D",
+            "DownEncoderBlockCausal3D",
+        ),
+        up_block_types: Tuple[str, ...] = (
+            "UpDecoderBlockCausal3D",
+            "UpDecoderBlockCausal3D",
+            "UpDecoderBlockCausal3D",
+            "UpDecoderBlockCausal3D",
+        ),
         block_out_channels: Tuple[int] = (128, 256, 512, 512),
         layers_per_block: int = 2,
         act_fn: str = "silu",
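Beyond the reflow, nothing changes here: the four repeated block-type strings still pair one-to-one with `block_out_channels`. As a rough sketch of how such tuples are typically consumed when building an encoder (the dispatcher below is hypothetical, not this file's actual helper):

from typing import Tuple

def get_down_block(block_type: str, in_channels: int, out_channels: int):
    # Hypothetical dispatcher; real diffusers code switches on the block-type
    # string and instantiates the corresponding nn.Module.
    if block_type == "DownEncoderBlockCausal3D":
        return (block_type, in_channels, out_channels)
    raise ValueError(f"Unknown down block type: {block_type}")

down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",) * 4
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512)

blocks, in_ch = [], block_out_channels[0]
for block_type, out_ch in zip(down_block_types, block_out_channels):
    blocks.append(get_down_block(block_type, in_ch, out_ch))
    in_ch = out_ch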

src/diffusers/models/transformers/transformer_hunyuan_video.py

Lines changed: 36 additions & 23 deletions
@@ -156,7 +156,9 @@ def apply_rotary_emb(
 class HunyuanVideoAttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )

     def __call__(
         self,
@@ -186,18 +188,24 @@ def __call__(
         from ..embeddings import apply_rotary_emb

         if attn.add_q_proj is None and encoder_hidden_states is not None:
-            query = torch.cat([
-                apply_rotary_emb(query[:, :, :-encoder_hidden_states.shape[1]], image_rotary_emb),
-                query[:, :, -encoder_hidden_states.shape[1]:],
-            ], dim=2)
-            key = torch.cat([
-                apply_rotary_emb(key[:, :, :-encoder_hidden_states.shape[1]], image_rotary_emb),
-                key[:, :, -encoder_hidden_states.shape[1]:],
-            ], dim=2)
+            query = torch.cat(
+                [
+                    apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                    query[:, :, -encoder_hidden_states.shape[1] :],
+                ],
+                dim=2,
+            )
+            key = torch.cat(
+                [
+                    apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                    key[:, :, -encoder_hidden_states.shape[1] :],
+                ],
+                dim=2,
+            )
         else:
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
-
+
         if attn.add_q_proj is not None and encoder_hidden_states is not None:
             encoder_query = attn.add_q_proj(encoder_hidden_states)
             encoder_key = attn.add_k_proj(encoder_hidden_states)
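This hunk is pure reformatting (black rewraps the `torch.cat` calls and spaces out the slices), but the pattern it touches is easy to misread: rotary embeddings are applied only to the leading image/video tokens, while the text tokens concatenated at the end of the sequence pass through unchanged. A self-contained sketch of that slicing, with an identity stand-in for `apply_rotary_emb` and assumed shapes:

import torch

def apply_rotary_emb(x, freqs):
    # Identity stand-in; the real helper rotates channel pairs by the angles in freqs.
    return x

batch, heads, img_len, txt_len, dim = 1, 2, 8, 3, 16
query = torch.randn(batch, heads, img_len + txt_len, dim)
image_rotary_emb = None  # placeholder for the (cos, sin) tables

# RoPE on the first img_len tokens only; the trailing txt_len text tokens are untouched.
query = torch.cat(
    [
        apply_rotary_emb(query[:, :, :-txt_len], image_rotary_emb),
        query[:, :, -txt_len:],
    ],
    dim=2,
)
assert query.shape == (batch, heads, img_len + txt_len, dim)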
@@ -224,8 +232,8 @@ def __call__(

         if encoder_hidden_states is not None:
             hidden_states, encoder_hidden_states = (
-                hidden_states[:, :-encoder_hidden_states.shape[1]],
-                hidden_states[:, -encoder_hidden_states.shape[1]:],
+                hidden_states[:, : -encoder_hidden_states.shape[1]],
+                hidden_states[:, -encoder_hidden_states.shape[1] :],
             )

         if not attn.pre_only:
@@ -513,7 +521,7 @@ def __init__(

         hidden_size = num_attention_heads * attention_head_dim
         mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
-
+
         self.hidden_size = hidden_size
         self.heads_num = num_attention_heads
         self.mlp_hidden_dim = mlp_hidden_dim
@@ -546,14 +554,17 @@ def forward(
     ) -> torch.Tensor:
         text_seq_length = encoder_hidden_states.shape[1]
         hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
-
+
         residual = hidden_states
-
+
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
-        norm_hidden_states, norm_encoder_hidden_states = norm_hidden_states[:, :-text_seq_length, :], norm_hidden_states[:, -text_seq_length:, :]
-
+
+        norm_hidden_states, norm_encoder_hidden_states = (
+            norm_hidden_states[:, :-text_seq_length, :],
+            norm_hidden_states[:, -text_seq_length:, :],
+        )
+
         # qkv, mlp = torch.split(self.linear1(norm_hidden_states), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
         # q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)

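For orientation, the lines being rewrapped implement the single block's concat-then-split pattern: image and text tokens are fused along the sequence axis so attention runs jointly over both, then split back apart using `text_seq_length`. A standalone sketch with assumed shapes:

import torch

img_tokens = torch.randn(2, 10, 64)  # (batch, image_seq_len, dim), assumed shapes
txt_tokens = torch.randn(2, 4, 64)   # (batch, text_seq_length, dim)
text_seq_length = txt_tokens.shape[1]

# Fuse for joint attention over all image + text tokens.
hidden_states = torch.cat([img_tokens, txt_tokens], dim=1)
# ... attention and the MLP would run on the fused sequence here ...

# Split back: everything except the last text_seq_length tokens is image.
hidden_states, encoder_hidden_states = (
    hidden_states[:, :-text_seq_length, :],
    hidden_states[:, -text_seq_length:, :],
)
assert hidden_states.shape == img_tokens.shape
assert encoder_hidden_states.shape == txt_tokens.shape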
@@ -584,7 +595,10 @@ def forward(
         hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
         hidden_states = hidden_states + residual

-        hidden_states, encoder_hidden_states = hidden_states[:, :-text_seq_length, :], hidden_states[:, -text_seq_length:, :]
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, :-text_seq_length, :],
+            hidden_states[:, -text_seq_length:, :],
+        )
         return hidden_states, encoder_hidden_states

@@ -631,7 +645,7 @@ def forward(
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+
         img_qkv = self.img_attn_qkv(norm_hidden_states)
         img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
         # Apply QK-Norm if needed
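The `rearrange` context line above unpacks a fused QKV projection. In case the einops pattern is unfamiliar, a minimal sketch (dimensions are assumptions for illustration):

import torch
from einops import rearrange

B, L, H, D = 2, 10, 4, 16
qkv = torch.randn(B, L, 3 * H * D)  # output of a fused qkv linear layer

# "(K H D)" factors the last axis into stacked q/k/v (K=3), heads, and head dim;
# moving K to the front lets tuple unpacking split out the three tensors.
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=H)
assert q.shape == (B, L, H, D)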
@@ -645,7 +659,7 @@ def forward(
             img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
         ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
         img_q, img_k = img_qq, img_kk
-
+
         txt_qkv = self.txt_attn_qkv(norm_encoder_hidden_states)
         txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
         txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
@@ -830,7 +844,6 @@ def forward(
         encoder_hidden_states = self.txt_in(encoder_hidden_states, timestep, encoder_attention_mask)

         txt_seq_len = encoder_hidden_states.shape[1]
-        img_seq_len = hidden_states.shape[1]

         freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
         for _, block in enumerate(self.transformer_blocks):
@@ -842,7 +855,7 @@ def forward(
             ]

             hidden_states, encoder_hidden_states = block(*double_block_args)
-
+
         for block in self.single_transformer_blocks:
             single_block_args = [
                 hidden_states,
