
Commit 27ce75b

add 5b t2v
1 parent bf2c6e0 commit 27ce75b

File tree: 3 files changed (+15 / -14 lines)


src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 12 additions & 11 deletions
@@ -236,7 +236,7 @@ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
         super().__init__()
         self.dim = dim
         self.mode = mode
-
+
         # default to dim //2
         if upsample_out_dim is None:
             upsample_out_dim = dim // 2
@@ -524,7 +524,7 @@ class WanEncoder3d(nn.Module):
 
     def __init__(
         self,
-        in_channels: int = 3,
+        in_channels: int = 3,
         dim=128,
         z_dim=4,
         dim_mult=[1, 2, 4, 4],
@@ -558,10 +558,10 @@ def __init__(
             if is_residual:
                 self.down_blocks.append(
                     WanResidualDownBlock(
-                        in_dim,
-                        out_dim,
-                        dropout,
-                        num_res_blocks,
+                        in_dim,
+                        out_dim,
+                        dropout,
+                        num_res_blocks,
                         temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
                         down_flag=i != len(dim_mult) - 1,
                     )
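
The stage flags in this constructor follow a common encoder pattern: every stage downsamples except the last, which is why both arguments guard on `i != len(dim_mult) - 1`. A toy loop showing how those guards play out; the flag values below are illustrative, not the model's actual config:

dim_mult = [1, 2, 4, 4]
temperal_downsample = [False, True, True]  # illustrative per-stage flags (one fewer than stages)

for i in range(len(dim_mult)):
    last_stage = i == len(dim_mult) - 1
    # The guard also keeps i from indexing past the shorter temperal_downsample list.
    t_down = temperal_downsample[i] if not last_stage else False
    down_flag = not last_stage
    print(f"stage {i}: temperal_downsample={t_down}, down_flag={down_flag}")
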
@@ -708,10 +708,10 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
             x = self.upsampler(x, feat_cache, feat_idx)
         else:
             x = self.upsampler(x)
-
+
         if self.avg_shortcut is not None:
             x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
-
+
         return x
 
 class WanUpBlock(nn.Module):
@@ -912,10 +912,9 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         return x
 
 
-# YiYi TODO: refactor this
-from einops import rearrange
-
 def patchify(x, patch_size):
+    # YiYi TODO: refactor this
+    from einops import rearrange
     if patch_size == 1:
         return x
     if x.dim() == 4:
@@ -935,6 +934,8 @@ def patchify(x, patch_size):
 
 
 def unpatchify(x, patch_size):
+    # YiYi TODO: refactor this
+    from einops import rearrange
     if patch_size == 1:
         return x
 
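
With this change, `einops` is imported lazily inside patchify and unpatchify instead of at module scope. For orientation, a minimal sketch of the kind of space-to-channel rearrange these helpers perform on 4D inputs; the exact einops pattern (in particular the axis order inside the grouped channel) is an assumption here, not copied from the file:

import torch
from einops import rearrange

def patchify_sketch(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Fold each (patch_size x patch_size) spatial tile into the channel axis.
    if patch_size == 1:
        return x
    return rearrange(x, "b c (h q) (w r) -> b (c q r) h w", q=patch_size, r=patch_size)

def unpatchify_sketch(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Inverse: unfold the channel axis back into spatial tiles.
    if patch_size == 1:
        return x
    return rearrange(x, "b (c q r) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)

x = torch.randn(1, 3, 8, 8)
assert torch.equal(unpatchify_sketch(patchify_sketch(x, 2), 2), x)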

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 2 additions & 2 deletions
@@ -171,7 +171,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
         timestep_seq_len: Optional[int] = None,
-    ):
+    ):
         timestep = self.timesteps_proj(timestep)
         if timestep_seq_len is not None:
             timestep = timestep.unflatten(0, (1, timestep_seq_len))
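
When `timestep_seq_len` is passed, the incoming timesteps arrive as a flat run of per-token values, and the unflatten restores an explicit batch axis of 1 in front of the sequence axis. A one-liner with illustrative shapes:

import torch

timestep = torch.arange(16.0)               # flattened per-token timesteps, shape (16,)
timestep = timestep.unflatten(0, (1, 16))   # -> shape (1, 16)
assert timestep.shape == (1, 16)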
@@ -518,7 +518,7 @@ def forward(
             hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
         # 5. Output norm, projection & unpatchify
-        if temb.ndim ==3:
+        if temb.ndim ==3:
             # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
             shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
             shift = shift.squeeze(2)
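
The `temb.ndim == 3` branch covers Wan 2.2 ti2v, where the timestep embedding carries a sequence axis and the shift/scale modulation is computed per token instead of per batch. A shape-level sketch of that broadcast; the (1, 2, inner_dim) table layout is assumed for illustration:

import torch

batch_size, seq_len, inner_dim = 2, 16, 64
scale_shift_table = torch.zeros(1, 2, inner_dim)       # assumed learned table layout
temb = torch.randn(batch_size, seq_len, inner_dim)     # per-token embedding, ndim == 3

# (1, 1, 2, inner) + (batch, seq, 1, inner) broadcasts to (batch, seq, 2, inner),
# then the size-2 axis is split into one (shift, scale) pair per token.
shift, scale = (scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
shift, scale = shift.squeeze(2), scale.squeeze(2)
assert shift.shape == (batch_size, seq_len, inner_dim)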

src/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 1 addition & 1 deletion
@@ -318,7 +318,7 @@ def check_inputs(
             not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
-
+
         if self.config.boundary_ratio is None and guidance_scale_2 is not None:
             raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
 
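
For context, this check ties `guidance_scale_2` to a configured `boundary_ratio`, since the second guidance scale only applies when the pipeline switches denoising stages at that boundary. A standalone sketch of the validation pattern; the free function and its signature are illustrative, not the pipeline's actual API:

from typing import Optional, Union

def check_inputs_sketch(
    negative_prompt: Optional[Union[str, list]],
    boundary_ratio: Optional[float],
    guidance_scale_2: Optional[float],
) -> None:
    # negative_prompt, when provided, must be a str or a list.
    if negative_prompt is not None and not isinstance(negative_prompt, (str, list)):
        raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
    # guidance_scale_2 is only meaningful with a boundary_ratio configured.
    if boundary_ratio is None and guidance_scale_2 is not None:
        raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")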
