Skip to content

Commit 9c19bda

Browse files
committed
update middle, test up_block
1 parent 2e12d1b commit 9c19bda

File tree

1 file changed

+199
-26
lines changed

1 file changed

+199
-26
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 199 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
from typing import Optional, Tuple, List, Union
16-
from einops import rearrange
1716

1817
import numpy as np
1918
import torch
@@ -93,6 +92,7 @@ def forward(self, x, cache_x=None):
9392
return super().forward(x)
9493

9594

95+
# TODO: not used yet; it does not affect the state dict, so it can be refactored in a follow-up PR
9696
class WanCausalConv3dYiYi(nn.Conv3d):
9797
r"""
9898
A custom 3D causal convolution layer with feature caching support.
@@ -401,9 +401,7 @@ def __init__(self, dim):
401401
self.norm = WanRMS_norm(dim)
402402
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
403403
self.proj = nn.Conv2d(dim, dim, 1)
404-
405-
# zero out the last layer params
406-
nn.init.zeros_(self.proj.weight)
404+
407405

408406
def forward(self, x):
409407
identity = x
@@ -529,11 +527,7 @@ def __init__(
529527
scale /= 2.0
530528

531529
# middle blocks
532-
self.middle = nn.Sequential(
533-
WanResidualBlock(out_dim, out_dim, dropout),
534-
WanAttentionBlock(out_dim),
535-
WanResidualBlock(out_dim, out_dim, dropout)
536-
)
530+
self.middle = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
537531

538532
# output blocks
539533
self.head = nn.Sequential(
@@ -563,11 +557,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
563557
x = layer(x)
564558

565559
## middle
566-
for layer in self.middle:
567-
if isinstance(layer, WanResidualBlock) and feat_cache is not None:
568-
x = layer(x, feat_cache, feat_idx)
569-
else:
570-
x = layer(x)
560+
x = self.middle(x, feat_cache, feat_idx)
571561

572562
## head
573563
for layer in self.head:
@@ -585,6 +575,78 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
585575
return x
586576

587577

578+
class WanUpBlock(nn.Module):
    """
    A block that handles upsampling for the WanVAE decoder.

    The block applies ``num_res_blocks + 1`` residual blocks in sequence, each one optionally followed by an
    attention block, and finally a single resample layer when ``upsample_mode`` is given.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks (one extra residual block is always added)
        dropout (float): Dropout rate
        use_attention (bool): Whether to add an attention block after every residual block
        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d'); `None` disables it
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        use_attention: bool = False,
        upsample_mode: Optional[str] = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Only the first residual block maps in_dim -> out_dim; the rest keep out_dim.
        resnets = []
        attentions = []
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
            if use_attention:
                attentions.append(WanAttentionBlock(out_dim))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)
        # Stays an (empty) ModuleList when attention is disabled so the attribute type is stable.
        self.attentions = nn.ModuleList(attentions)

        # Optional resample layer, wrapped in a ModuleList so it is addressed as `upsamplers.0`.
        self.upsamplers = None
        if upsample_mode is not None:
            self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])

        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Feature index for cache management. It is only forwarded to the residual
                blocks here; callers that use `feat_cache` should pass their own list.

        Returns:
            torch.Tensor: Output tensor
        """
        # `self.attentions` is empty when attention is disabled. Zipping it with the resnets would stop
        # immediately and silently skip every residual block, so iterate over the resnets instead.
        has_attention = len(self.attentions) > 0
        for i, resnet in enumerate(self.resnets):
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)
            if has_attention:
                x = self.attentions[i](x)

        if self.upsamplers is not None:
            # NOTE(review): the resample layer is called without feat_cache/feat_idx, as in the original code.
            # If WanResample relies on the feature cache (e.g. for 'upsample3d'), it should be forwarded here
            # -- confirm against WanResample.forward.
            x = self.upsamplers[0](x)
        return x
648+
649+
588650
class WanDecoder3d(nn.Module):
589651
r"""
590652
A 3D decoder module.
@@ -628,12 +690,7 @@ def __init__(
628690
self.conv1 = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
629691

630692
# middle blocks
631-
self.middle = nn.Sequential(
632-
WanResidualBlock(dims[0], dims[0], dropout),
633-
WanAttentionBlock(dims[0]),
634-
WanResidualBlock(dims[0], dims[0], dropout)
635-
)
636-
693+
self.middle = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
637694
# upsample blocks
638695
upsamples = []
639696
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
@@ -674,12 +731,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
674731
else:
675732
x = self.conv1(x)
676733

677-
## middle
678-
for layer in self.middle:
679-
if isinstance(layer, WanResidualBlock) and feat_cache is not None:
680-
x = layer(x, feat_cache, feat_idx)
681-
else:
682-
x = layer(x)
734+
x = self.middle(x, feat_cache, feat_idx)
683735

684736
## upsamples
685737
for layer in self.upsamples:
@@ -704,6 +756,127 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
704756
return x
705757

706758

759+
class WanDecoder3dYiYi(nn.Module):
    r"""
    A 3D decoder module.

    The decoder runs an input causal convolution, a middle block, a stack of `WanUpBlock`s and an output head
    (norm -> non-linearity -> causal convolution producing 3 channels).

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_upsample (list of bool): Whether to upsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample

        self.nonlinearity = get_activation(non_linearity)

        # dimensions: the widest stage comes first (and is repeated once), then the multipliers in reverse
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2 ** (len(dim_mult) - 2)

        # init block
        self.conv1 = WanCausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
        self.middle = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)

        # upsample blocks
        upsamples = nn.ModuleList([])
        last_stage = len(dim_mult) - 1
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # NOTE(review): every stage after the first receives half the nominal channel count, presumably
            # because the previous stage's resample layer halves the channels -- confirm against WanResample.
            if i > 0:
                in_dim = in_dim // 2

            # Determine if we need attention and upsampling; the last stage never upsamples
            use_attention = scale in attn_scales
            upsample_mode = None
            if i != last_stage:
                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"

            upsamples.append(
                WanUpBlock(
                    in_dim=in_dim,
                    out_dim=out_dim,
                    num_res_blocks=num_res_blocks,
                    dropout=dropout,
                    use_attention=use_attention,
                    upsample_mode=upsample_mode,
                    non_linearity=non_linearity,
                )
            )

            # Update scale for next iteration
            if upsample_mode is not None:
                scale *= 2.0

        self.upsamples = upsamples

        # output blocks; dims[-1] is the out_dim of the last upsample stage
        head_dim = dims[-1]
        self.head = nn.Sequential(
            WanRMS_norm(head_dim, images=False),
            self.nonlinearity,
            WanCausalConv3d(head_dim, 3, 3, padding=1),
        )

    def _cached_causal_conv(self, conv, x, feat_cache, feat_idx):
        """Apply a causal conv and, when a feature cache is given, refresh its slot and advance `feat_idx`."""
        if feat_cache is None:
            return conv(x)

        idx = feat_idx[0]
        # Snapshot the trailing CACHE_T entries along dim 2 *before* the conv so the next chunk can reuse them.
        cache_x = x[:, :, -CACHE_T:, :, :].clone()
        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
            # The current chunk is too short on its own: prepend the last frame of the previously cached chunk.
            cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
        x = conv(x, feat_cache[idx])
        feat_cache[idx] = cache_x
        feat_idx[0] += 1
        return x

    def forward(self, x, feat_cache=None, feat_idx=None):
        """
        Decode `x`.

        Args:
            x (torch.Tensor): Input tensor, indexed with five dimensions when `feat_cache` is used.
            feat_cache (list, optional): Per-convolution feature cache; slots are read and overwritten in place.
            feat_idx (list, optional): Single-element list holding the next cache slot index. It is advanced in
                place. Defaults to a fresh `[0]` on every call so the counter never leaks between calls.

        Returns:
            torch.Tensor: Output of the head.
        """
        # A shared mutable default would keep the incremented index across calls.
        if feat_idx is None:
            feat_idx = [0]

        ## conv1
        x = self._cached_causal_conv(self.conv1, x, feat_cache, feat_idx)

        ## middle
        x = self.middle(x, feat_cache, feat_idx)

        ## upsamples
        for up_block in self.upsamples:
            x = up_block(x, feat_cache, feat_idx)

        ## head
        for layer in self.head:
            # Only the causal conv takes part in caching; norm and activation are applied directly.
            if feat_cache is not None and isinstance(layer, WanCausalConv3d):
                x = self._cached_causal_conv(layer, x, feat_cache, feat_idx)
            else:
                x = layer(x)
        return x
878+
879+
707880
class AutoencoderKLWan(ModelMixin, ConfigMixin):
708881
r"""
709882
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.

0 commit comments

Comments
 (0)