
Commit 2da1feb

Commit message: "up up"

1 parent 9c19bda commit 2da1feb

File tree: 1 file changed (+9, -124 lines)

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 9 additions & 124 deletions
@@ -271,7 +271,7 @@ def __init__(self, dim: int, mode: str) -> None:
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         b, c, t, h, w = x.size()
         if self.mode == 'upsample3d':
-            if feat_cache is not None:
+            if feat_cache is not None:
                 idx = feat_idx[0]
                 if feat_cache[idx] is None:
                     feat_cache[idx] = 'Rep'
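For context on what this hunk touches: WanResample.forward takes a feat_cache list and a one-element feat_idx counter so that video chunks can be decoded sequentially. Below is a minimal, self-contained sketch of that bookkeeping pattern with made-up names (run_chunks, toy_layer); only the idx = feat_idx[0] lookup and the 'Rep' placeholder for the first chunk come from the hunk above, and CACHE_T = 2 is an assumed value for the module's constant.

import torch

CACHE_T = 2  # assumed value of the module-level caching constant

def toy_layer(x, cached=None):
    # stand-in for a causal conv that could consume a cached temporal prefix
    return x

def run_chunks(chunks):
    feat_cache = [None]  # one slot per cached layer; a single toy layer here
    out = None
    for x in chunks:
        feat_idx = [0]  # reset per forward pass, advanced once per cached layer
        idx = feat_idx[0]
        if feat_cache[idx] is None:
            # first chunk: mark the slot so the layer knows there is no history yet
            feat_cache[idx] = 'Rep'
        out = toy_layer(x, feat_cache[idx])
        # keep the trailing frames of this chunk for the next call
        feat_cache[idx] = x[:, :, -CACHE_T:].clone()
        feat_idx[0] += 1
    return out

run_chunks([torch.randn(1, 4, 3, 8, 8) for _ in range(2)])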
@@ -403,7 +403,7 @@ def __init__(self, dim):
         self.proj = nn.Conv2d(dim, dim, 1)


-    def forward(self, x):
+    def forward(self, x):
         identity = x
         batch_size, channels, time, height, width = x.size()

@@ -427,7 +427,7 @@ def forward(self, x):
         # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
         x = x.view(batch_size, time, channels, height, width)
         x = x.permute(0, 2, 1, 3, 4)
-
+
         return x + identity


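A side note on the context lines in this hunk: the "Reshape back" comment corresponds to a view followed by a permute that undoes an earlier fold of the time axis into the batch axis (the fold itself is not shown in this hunk). A quick self-contained check of that round trip, with arbitrary shapes:

import torch

b, c, t, h, w = 2, 8, 5, 16, 16
x = torch.randn(b, c, t, h, w)

# Fold time into the batch (assumed to happen earlier in the forward pass):
# [b, c, t, h, w] -> [(b*t), c, h, w]
folded = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

# Reshape back, as in the hunk: [(b*t), c, h, w] -> [b, t, c, h, w] -> [b, c, t, h, w]
restored = folded.view(b, t, c, h, w).permute(0, 2, 1, 3, 4)

assert torch.equal(restored, x)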
@@ -584,7 +584,6 @@ class WanUpBlock(nn.Module):
         out_dim (int): Output dimension
         num_res_blocks (int): Number of residual blocks
         dropout (float): Dropout rate
-        use_attention (bool): Whether to use attention
         upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
         non_linearity (str): Type of non-linearity to use
     """
@@ -594,7 +593,6 @@ def __init__(
         out_dim: int,
         num_res_blocks: int,
         dropout: float = 0.0,
-        use_attention: bool = False,
         upsample_mode: Optional[str] = None,
         non_linearity: str = "silu",
     ):
@@ -604,17 +602,13 @@ def __init__(

         # Create layers list
         resnets = []
-        attentions = []
         # Add residual blocks and attention if needed
         current_dim = in_dim
         for _ in range(num_res_blocks + 1):
             resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
-            if use_attention:
-                attentions.append(WanAttentionBlock(out_dim))
             current_dim = out_dim

         self.resnets = nn.ModuleList(resnets)
-        self.attentions = nn.ModuleList(attentions)

         # Add upsampling layer if needed
         self.upsamplers = None
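After this hunk, WanUpBlock builds only a list of residual blocks; the parallel attentions list and the use_attention branch are gone. A self-contained sketch of the resulting constructor shape, using a hypothetical _ToyResidualBlock in place of WanResidualBlock:

import torch.nn as nn

class _ToyResidualBlock(nn.Module):
    # hypothetical stand-in for WanResidualBlock, only so the sketch runs
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv = nn.Conv3d(in_dim, out_dim, 1)

    def forward(self, x):
        return self.conv(x)

class ToyUpBlock(nn.Module):
    # mirrors the structure left in WanUpBlock by this commit: resnets only
    def __init__(self, in_dim, out_dim, num_res_blocks):
        super().__init__()
        resnets = []
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(_ToyResidualBlock(current_dim, out_dim))
            current_dim = out_dim
        self.resnets = nn.ModuleList(resnets)
        self.upsamplers = None  # set elsewhere when an upsample_mode is given

print(len(ToyUpBlock(8, 16, 2).resnets))  # num_res_blocks + 1 == 3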
@@ -635,128 +629,21 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         Returns:
             torch.Tensor: Output tensor
         """
-        for resnet, attention in zip(self.resnets, self.attentions):
+        for resnet in self.resnets:
             if feat_cache is not None:
                 x = resnet(x, feat_cache, feat_idx)
             else:
                 x = resnet(x)
-            if attention is not None:
-                x = attention(x)
-        if self.upsamplers is not None:
-            x = self.upsamplers[0](x)
-        return x
-
-
-class WanDecoder3d(nn.Module):
-    r"""
-    A 3D decoder module.
-
-    Args:
-        dim (int): The base number of channels in the first layer.
-        z_dim (int): The dimensionality of the latent space.
-        dim_mult (list of int): Multipliers for the number of channels in each block.
-        num_res_blocks (int): Number of residual blocks in each block.
-        attn_scales (list of float): Scales at which to apply attention mechanisms.
-        temperal_upsample (list of bool): Whether to upsample temporally in each block.
-        dropout (float): Dropout rate for the dropout layers.
-        non_linearity (str): Type of non-linearity to use.
-    """
-    def __init__(
-        self,
-        dim=128,
-        z_dim=4,
-        dim_mult=[1, 2, 4, 4],
-        num_res_blocks=2,
-        attn_scales=[],
-        temperal_upsample=[False, True, True],
-        dropout=0.0,
-        non_linearity: str = "silu",
-    ):
-        super().__init__()
-        self.dim = dim
-        self.z_dim = z_dim
-        self.dim_mult = dim_mult
-        self.num_res_blocks = num_res_blocks
-        self.attn_scales = attn_scales
-        self.temperal_upsample = temperal_upsample
-
-        self.nonlinearity = get_activation(non_linearity)
-
-        # dimensions
-        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
-        scale = 1.0 / 2 ** (len(dim_mult) - 2)
-
-        # init block
-        self.conv1 = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
-
-        # middle blocks
-        self.middle = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
-        # upsample blocks
-        upsamples = []
-        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
-            # residual (+attention) blocks
-            if i > 0:
-                in_dim = in_dim //2
-            for _ in range(num_res_blocks + 1):
-                upsamples.append(WanResidualBlock(in_dim, out_dim, dropout))
-                if scale in attn_scales:
-                    upsamples.append(WanAttentionBlock(out_dim))
-                in_dim = out_dim
-
-            # upsample block
-            if i != len(dim_mult) - 1:
-                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
-                upsamples.append(WanResample(out_dim, mode=mode))
-                scale *= 2.0
-        self.upsamples = nn.Sequential(*upsamples)

-        # output blocks
-        self.head = nn.Sequential(
-            WanRMS_norm(out_dim, images=False),
-            self.nonlinearity,
-            WanCausalConv3d(out_dim, 3, 3, padding=1)
-        )
-
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
-        ## conv1
-        if feat_cache is not None:
-            idx = feat_idx[0]
-            cache_x = x[:, :, -CACHE_T:, :, :].clone()
-            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
-                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-            x = self.conv1(x, feat_cache[idx])
-            feat_cache[idx] = cache_x
-            feat_idx[0] += 1
-        else:
-            x = self.conv1(x)
-
-        x = self.middle(x, feat_cache, feat_idx)
-
-        ## upsamples
-        for layer in self.upsamples:
+        if self.upsamplers is not None:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
-            else:
-                x = layer(x)
-
-        ## head
-        for layer in self.head:
-            if isinstance(layer, WanCausalConv3d) and feat_cache is not None:
-                idx = feat_idx[0]
-                cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                    # cache last frame of last two chunk
-                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-                x = layer(x, feat_cache[idx])
-                feat_cache[idx] = cache_x
-                feat_idx[0] += 1
+                x = self.upsamplers[0](x, feat_cache, feat_idx)
             else:
-                x = layer(x)
+                x = self.upsamplers[0](x)
         return x


-class WanDecoder3dYiYi(nn.Module):
+class WanDecoder3d(nn.Module):
     r"""
     A 3D decoder module.

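Because this hunk interleaves the removal of the old WanDecoder3d with the edits to WanUpBlock.forward, here is the forward control flow that the + and context lines add up to, wrapped in a runnable toy module. The _SketchUpBlock / _ToyResnet names and the tensor shape are assumptions; the body of forward follows the diff.

import torch
import torch.nn as nn

class _ToyResnet(nn.Module):
    # hypothetical stand-in so the control flow below can execute
    def forward(self, x, feat_cache=None, feat_idx=[0]):
        return x

class _SketchUpBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnets = nn.ModuleList([_ToyResnet(), _ToyResnet()])
        self.upsamplers = None

    # WanUpBlock.forward after this commit, per the + and context lines above:
    # one loop over resnets, then the optional upsampler, which now also
    # receives feat_cache / feat_idx when caching is enabled.
    def forward(self, x, feat_cache=None, feat_idx=[0]):
        for resnet in self.resnets:
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)

        if self.upsamplers is not None:
            if feat_cache is not None:
                x = self.upsamplers[0](x, feat_cache, feat_idx)
            else:
                x = self.upsamplers[0](x)
        return x

out = _SketchUpBlock()(torch.randn(1, 4, 2, 8, 8))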
@@ -809,8 +696,7 @@ def __init__(
             if i > 0:
                 in_dim = in_dim // 2

-            # Determine if we need attention and upsampling
-            use_attention = scale in attn_scales
+            # Determine if we need upsampling
             upsample_mode = None
             if i != len(dim_mult) - 1:
                 upsample_mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
@@ -821,7 +707,6 @@ def __init__(
                 out_dim=out_dim,
                 num_res_blocks=num_res_blocks,
                 dropout=dropout,
-                use_attention=use_attention,
                 upsample_mode=upsample_mode,
                 non_linearity=non_linearity,
             )
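Taking the last two hunks together, the decoder now selects an upsample_mode per stage and no longer computes a use_attention flag. A self-contained sketch of that wiring, using the default values shown for the removed decoder above; the loop header and the dims formula come from the removed class, the up_blocks name and the ToyUpBlock stand-in are assumptions, and only the argument list and the upsample_mode selection are taken from the diff.

dim, dim_mult = 128, [1, 2, 4, 4]
temperal_upsample = [False, True, True]
num_res_blocks, dropout, non_linearity = 2, 0.0, "silu"
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]  # [512, 512, 512, 256, 128]

def ToyUpBlock(**kwargs):  # hypothetical stand-in that just records its arguments
    return kwargs

up_blocks = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
    if i > 0:
        in_dim = in_dim // 2

    # Determine if we need upsampling (the use_attention flag is gone)
    upsample_mode = None
    if i != len(dim_mult) - 1:
        upsample_mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'

    up_blocks.append(ToyUpBlock(
        in_dim=in_dim,
        out_dim=out_dim,
        num_res_blocks=num_res_blocks,
        dropout=dropout,
        upsample_mode=upsample_mode,
        non_linearity=non_linearity,
    ))

print([b["upsample_mode"] for b in up_blocks])  # ['upsample2d', 'upsample3d', 'upsample3d', None]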
