# limitations under the License.

from typing import Optional, Tuple, List, Union
-from einops import rearrange

import numpy as np
import torch
@@ -93,6 +92,7 @@ def forward(self, x, cache_x=None):
        return super().forward(x)


+# TODO: not used yet; will not affect the state dict, so it can be refactored in a follow-up PR
class WanCausalConv3dYiYi(nn.Conv3d):
    r"""
    A custom 3D causal convolution layer with feature caching support.
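For readers unfamiliar with the pattern, a causal 3D convolution pads the time axis only on the past side, so the output at frame t never depends on later frames, while height and width keep ordinary symmetric padding. A minimal, self-contained sketch of the idea (hypothetical helper class, not the WanCausalConv3dYiYi in this PR, which additionally supports feature caching):

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv3dSketch(nn.Conv3d):
    """Conv3d whose temporal padding is applied entirely on the past side."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0)
        kt, kh, kw = self.kernel_size
        # F.pad order for a 5D tensor: (w_left, w_right, h_top, h_bottom, t_front, t_back)
        self._causal_pad = (kw // 2, kw // 2, kh // 2, kh // 2, kt - 1, 0)

    def forward(self, x):
        return super().forward(F.pad(x, self._causal_pad))

x = torch.randn(1, 8, 5, 32, 32)               # (batch, channels, frames, height, width)
y = CausalConv3dSketch(8, 16, kernel_size=3)(x)
print(y.shape)                                  # torch.Size([1, 16, 5, 32, 32])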
@@ -401,9 +401,7 @@ def __init__(self, dim):
        self.norm = WanRMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)
-
-        # zero out the last layer params
-        nn.init.zeros_(self.proj.weight)
+

    def forward(self, x):
        identity = x
@@ -529,11 +527,7 @@ def __init__(
                scale /= 2.0

        # middle blocks
-        self.middle = nn.Sequential(
-            WanResidualBlock(out_dim, out_dim, dropout),
-            WanAttentionBlock(out_dim),
-            WanResidualBlock(out_dim, out_dim, dropout)
-        )
+        self.middle = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)

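The WanMidBlock referenced here is defined outside this hunk. Judging from the nn.Sequential it replaces (residual block, attention block, residual block) and from the call sites in the forward passes below, it presumably looks roughly like the following sketch; only the name WanMidBlock comes from the diff, everything else is an assumption rather than the exact PR code:

class WanMidBlockSketch(nn.Module):
    # Hedged reconstruction: one leading residual block, then num_layers pairs of
    # (attention, residual), reusing the WanResidualBlock / WanAttentionBlock defined
    # earlier in this file and threading the causal-conv feature cache through.
    def __init__(self, dim, dropout=0.0, non_linearity="silu", num_layers=1):
        super().__init__()
        resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
        attentions = []
        for _ in range(num_layers):
            attentions.append(WanAttentionBlock(dim))
            resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions)

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        x = self.resnets[0](x, feat_cache, feat_idx) if feat_cache is not None else self.resnets[0](x)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            x = attn(x)
            x = resnet(x, feat_cache, feat_idx) if feat_cache is not None else resnet(x)
        return x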
        # output blocks
        self.head = nn.Sequential(
@@ -563,11 +557,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
                x = layer(x)

        ## middle
-        for layer in self.middle:
-            if isinstance(layer, WanResidualBlock) and feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
-            else:
-                x = layer(x)
+        x = self.middle(x, feat_cache, feat_idx)

        ## head
        for layer in self.head:
@@ -585,6 +575,78 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
        return x


+class WanUpBlock(nn.Module):
+    """
+    A block that handles upsampling for the WanVAE decoder.
+
+    Args:
+        in_dim (int): Input dimension
+        out_dim (int): Output dimension
+        num_res_blocks (int): Number of residual blocks
+        dropout (float): Dropout rate
+        use_attention (bool): Whether to use attention
+        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
+        non_linearity (str): Type of non-linearity to use
+    """
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_res_blocks: int,
+        dropout: float = 0.0,
+        use_attention: bool = False,
+        upsample_mode: Optional[str] = None,
+        non_linearity: str = "silu",
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+
+        # Create layers list
+        resnets = []
+        attentions = []
+        # Add residual blocks and attention if needed
+        current_dim = in_dim
+        for _ in range(num_res_blocks + 1):
+            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+            if use_attention:
+                attentions.append(WanAttentionBlock(out_dim))
+            else:
+                # keep the two lists the same length so zip() in forward() still
+                # visits every resnet when attention is disabled
+                attentions.append(None)
+            current_dim = out_dim
+
+        self.resnets = nn.ModuleList(resnets)
+        self.attentions = nn.ModuleList(attentions)
+
+        # Add upsampling layer if needed
+        self.upsamplers = None
+        if upsample_mode is not None:
+            self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
+
+        self.gradient_checkpointing = False
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        """
+        Forward pass through the upsampling block.
+
+        Args:
+            x (torch.Tensor): Input tensor
+            feat_cache (list, optional): Feature cache for causal convolutions
+            feat_idx (list, optional): Feature index for cache management
+
+        Returns:
+            torch.Tensor: Output tensor
+        """
+        for resnet, attention in zip(self.resnets, self.attentions):
+            if feat_cache is not None:
+                x = resnet(x, feat_cache, feat_idx)
+            else:
+                x = resnet(x)
+            if attention is not None:
+                x = attention(x)
+        if self.upsamplers is not None:
+            x = self.upsamplers[0](x)
+        return x
+
+
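As a quick sanity check on the new block, here is a hedged usage sketch; shapes and argument values are illustrative assumptions, not taken from the PR or its tests:

up_block = WanUpBlock(
    in_dim=256,
    out_dim=128,
    num_res_blocks=2,
    dropout=0.0,
    use_attention=False,
    upsample_mode="upsample2d",
)
x = torch.randn(1, 256, 4, 16, 16)   # (batch, channels, frames, height, width)
y = up_block(x)                      # no feat_cache: plain, non-cached forward
# The 'upsample2d' resample should double height and width; the resulting channel
# count depends on WanResample's implementation, which is not shown in this hunk.
print(y.shape)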
class WanDecoder3d(nn.Module):
    r"""
    A 3D decoder module.
@@ -628,12 +690,7 @@ def __init__(
        self.conv1 = WanCausalConv3d(z_dim, dims[0], 3, padding=1)

        # middle blocks
-        self.middle = nn.Sequential(
-            WanResidualBlock(dims[0], dims[0], dropout),
-            WanAttentionBlock(dims[0]),
-            WanResidualBlock(dims[0], dims[0], dropout)
-        )
-
+        self.middle = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
        # upsample blocks
        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
@@ -674,12 +731,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
        else:
            x = self.conv1(x)

-        ## middle
-        for layer in self.middle:
-            if isinstance(layer, WanResidualBlock) and feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
-            else:
-                x = layer(x)
+        x = self.middle(x, feat_cache, feat_idx)

        ## upsamples
        for layer in self.upsamples:
@@ -704,6 +756,127 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
        return x


+class WanDecoder3dYiYi(nn.Module):
+    r"""
+    A 3D decoder module.
+
+    Args:
+        dim (int): The base number of channels in the first layer.
+        z_dim (int): The dimensionality of the latent space.
+        dim_mult (list of int): Multipliers for the number of channels in each block.
+        num_res_blocks (int): Number of residual blocks in each block.
+        attn_scales (list of float): Scales at which to apply attention mechanisms.
+        temperal_upsample (list of bool): Whether to upsample temporally in each block.
+        dropout (float): Dropout rate for the dropout layers.
+        non_linearity (str): Type of non-linearity to use.
+    """
+    def __init__(
+        self,
+        dim=128,
+        z_dim=4,
+        dim_mult=[1, 2, 4, 4],
+        num_res_blocks=2,
+        attn_scales=[],
+        temperal_upsample=[False, True, True],
+        dropout=0.0,
+        non_linearity: str = "silu",
+    ):
+        super().__init__()
+        self.dim = dim
+        self.z_dim = z_dim
+        self.dim_mult = dim_mult
+        self.num_res_blocks = num_res_blocks
+        self.attn_scales = attn_scales
+        self.temperal_upsample = temperal_upsample
+
+        self.nonlinearity = get_activation(non_linearity)
+
+        # dimensions
+        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
+        scale = 1.0 / 2 ** (len(dim_mult) - 2)
+
+        # init block
+        self.conv1 = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
+
+        # middle blocks
+        self.middle = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
+
+
+        # upsample blocks
+        upsamples = nn.ModuleList([])
+        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+            # residual (+attention) blocks
+            if i > 0:
+                in_dim = in_dim // 2
+
+            # Determine if we need attention and upsampling
+            use_attention = scale in attn_scales
+            upsample_mode = None
+            if i != len(dim_mult) - 1:
+                upsample_mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
+
+            # Create and add the upsampling block
+            up_block = WanUpBlock(
+                in_dim=in_dim,
+                out_dim=out_dim,
+                num_res_blocks=num_res_blocks,
+                dropout=dropout,
+                use_attention=use_attention,
+                upsample_mode=upsample_mode,
+                non_linearity=non_linearity,
+            )
+            upsamples.append(up_block)
+
+            # Update scale for next iteration
+            if upsample_mode is not None:
+                scale *= 2.0
+
+        self.upsamples = upsamples
+
+        # output blocks
+        self.head = nn.Sequential(
+            WanRMS_norm(out_dim, images=False),
+            self.nonlinearity,
+            WanCausalConv3d(out_dim, 3, 3, padding=1)
+        )
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        ## conv1
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunks
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv1(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv1(x)
+
+        ## middle
+        x = self.middle(x, feat_cache, feat_idx)
+
+        ## upsamples
+        for up_block in self.upsamples:
+            x = up_block(x, feat_cache, feat_idx)
+
+        ## head
+        for layer in self.head:
+            if isinstance(layer, WanCausalConv3d) and feat_cache is not None:
+                idx = feat_idx[0]
+                cache_x = x[:, :, -CACHE_T:, :, :].clone()
+                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # cache last frame of last two chunks
+                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+                x = layer(x, feat_cache[idx])
+                feat_cache[idx] = cache_x
+                feat_idx[0] += 1
+            else:
+                x = layer(x)
+        return x
+
+
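To make the construction above concrete: with the defaults, dims = [128 * u for u in [4] + [4, 4, 2, 1]] = [512, 512, 512, 256, 128] and scale = 1.0 / 2 ** 2 = 0.25. Below is a hedged driver sketch, not taken from the PR. Judging from the forward pass, feat_cache appears to be a flat list with one slot per causal convolution, initialized to None, while feat_idx is a one-element list used as a mutable counter that each cached convolution advances; for chunked decoding, feat_idx would presumably be reset to [0] before each chunk while feat_cache persists.

def count_causal_convs(module: nn.Module) -> int:
    # Assumption: one cache slot is needed per WanCausalConv3d in the decoder.
    return sum(isinstance(m, WanCausalConv3d) for m in module.modules())

decoder = WanDecoder3dYiYi()                       # defaults: dim=128, z_dim=4, ...
z = torch.randn(1, 4, 3, 16, 16)                   # (batch, z_dim, latent frames, h, w)

feat_cache = [None] * count_causal_convs(decoder)  # persists across chunks
feat_idx = [0]                                     # reset before each chunk
video_chunk = decoder(z, feat_cache=feat_cache, feat_idx=feat_idx)
print(video_chunk.shape)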
class AutoencoderKLWan(ModelMixin, ConfigMixin):
    r"""
    A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.