@@ -271,7 +271,7 @@ def __init__(self, dim: int, mode: str) -> None:
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         b, c, t, h, w = x.size()
         if self.mode == 'upsample3d':
-            if feat_cache is not None:                     
+            if feat_cache is not None:
                 idx = feat_idx[0]
                 if feat_cache[idx] is None:
                     feat_cache[idx] = 'Rep'
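A note on the hunk above, which only trims trailing whitespace but touches the caching convention used throughout this file: `feat_idx` is a one-element list acting as a mutable cursor into `feat_cache`, and `'Rep'` is a sentinel marking a slot with no history yet. A minimal illustrative sketch (cache size and the helper name are hypothetical, not from the file):

```python
# Illustrative sketch of the feat_cache / feat_idx convention; the real cache
# is sized by the model's causal convolutions, 4 is just for demonstration.
feat_cache = [None] * 4
feat_idx = [0]  # one-element list: callees can advance the cursor in place

def visit_cached_conv(feat_cache, feat_idx):
    idx = feat_idx[0]
    if feat_cache[idx] is None:
        feat_cache[idx] = 'Rep'  # sentinel: first chunk, nothing to reuse yet
    feat_idx[0] += 1             # mutation is visible to the caller

for _ in range(len(feat_cache)):
    visit_cached_conv(feat_cache, feat_idx)
print(feat_cache)  # ['Rep', 'Rep', 'Rep', 'Rep']
```

A plain `int` argument would be rebound locally inside each submodule, which is presumably why the list-of-one idiom is used everywhere `feat_idx` appears.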
@@ -403,7 +403,7 @@ def __init__(self, dim):
         self.proj = nn.Conv2d(dim, dim, 1)
 
 
-    def forward(self, x):  
+    def forward(self, x):
         identity = x
         batch_size, channels, time, height, width = x.size()
 
@@ -427,7 +427,7 @@ def forward(self, x):
         # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
         x = x.view(batch_size, time, channels, height, width)
         x = x.permute(0, 2, 1, 3, 4)
-        
+
         return x + identity
 
 
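The two hunks above are whitespace-only, but the reshape they sit in is easy to get wrong, so here is a self-contained shape check. It assumes the forward fold permutes to `[b, t, c, h, w]` before flattening, which is what the reshape-back shown in the diff implies:

```python
import torch

b, c, t, h, w = 2, 8, 4, 16, 16
x = torch.randn(b, c, t, h, w)
identity = x

# Fold time into the batch so 2D attention can treat frames independently.
y = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)

# (2D attention over [(b*t), c, h, w] would run here.)

# Reshape back, as in the diff: [(b*t), c, h, w] -> [b, c, t, h, w].
y = y.view(b, t, c, h, w).permute(0, 2, 1, 3, 4)
assert torch.equal(y, identity)  # the round-trip is lossless
```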
@@ -584,7 +584,6 @@ class WanUpBlock(nn.Module):
         out_dim (int): Output dimension
         num_res_blocks (int): Number of residual blocks
         dropout (float): Dropout rate
-        use_attention (bool): Whether to use attention
         upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
         non_linearity (str): Type of non-linearity to use
     """
@@ -594,7 +593,6 @@ def __init__(
         out_dim: int,
         num_res_blocks: int,
         dropout: float = 0.0,
-        use_attention: bool = False,
         upsample_mode: Optional[str] = None,
         non_linearity: str = "silu",
     ):
@@ -604,17 +602,13 @@ def __init__(
 
         # Create layers list
         resnets = []
-        attentions = []
         # Add residual blocks and attention if needed
         current_dim = in_dim
         for _ in range(num_res_blocks + 1):
             resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
-            if use_attention:
-                attentions.append(WanAttentionBlock(out_dim))
             current_dim = out_dim
 
         self.resnets = nn.ModuleList(resnets)
-        self.attentions = nn.ModuleList(attentions)
 
         # Add upsampling layer if needed
         self.upsamplers = None
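The three hunks above strip the `use_attention` path out of `WanUpBlock`. The removal looks safe: in the decoder further down, `attn_scales` defaults to `[]`, so `scale in attn_scales` was always `False` and the attention list stayed empty. A hedged construction sketch for the slimmed-down block (the dims are illustrative, and it assumes the module under edit is importable):

```python
# With num_res_blocks=2, the loop `for _ in range(num_res_blocks + 1)` builds
# three residual blocks: in_dim -> out_dim once, then out_dim -> out_dim twice.
block = WanUpBlock(
    in_dim=512,
    out_dim=256,
    num_res_blocks=2,
    dropout=0.0,
    upsample_mode="upsample2d",  # None on the final stage, per the decoder loop
)
```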
@@ -635,128 +629,21 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         Returns:
             torch.Tensor: Output tensor
         """
-        for resnet, attention in zip(self.resnets, self.attentions):
+        for resnet in self.resnets:
             if feat_cache is not None:
                 x = resnet(x, feat_cache, feat_idx)
             else:
                 x = resnet(x)
-            if attention is not None:
-                x = attention(x)
-        if self.upsamplers is not None:
-            x = self.upsamplers[0](x)
-        return x
-
-
-class WanDecoder3d(nn.Module):
-    r"""
-    A 3D decoder module.
-
-    Args:
-        dim (int): The base number of channels in the first layer.
-        z_dim (int): The dimensionality of the latent space.
-        dim_mult (list of int): Multipliers for the number of channels in each block.
-        num_res_blocks (int): Number of residual blocks in each block.
-        attn_scales (list of float): Scales at which to apply attention mechanisms.
-        temperal_upsample (list of bool): Whether to upsample temporally in each block.
-        dropout (float): Dropout rate for the dropout layers.
-        non_linearity (str): Type of non-linearity to use.
-    """
-    def __init__(
-        self,
-        dim=128,
-        z_dim=4,
-        dim_mult=[1, 2, 4, 4],
-        num_res_blocks=2,
-        attn_scales=[],
-        temperal_upsample=[False, True, True],
-        dropout=0.0,
-        non_linearity: str = "silu",
-    ):
-        super().__init__()
-        self.dim = dim
-        self.z_dim = z_dim
-        self.dim_mult = dim_mult
-        self.num_res_blocks = num_res_blocks
-        self.attn_scales = attn_scales
-        self.temperal_upsample = temperal_upsample
-
-        self.nonlinearity = get_activation(non_linearity)
-
-        # dimensions
-        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
-        scale = 1.0 / 2 ** (len(dim_mult) - 2)
-
-        # init block
-        self.conv1 = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
-
-        # middle blocks
-        self.middle = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
-        # upsample blocks
-        upsamples = []
-        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
-            # residual (+attention) blocks
-            if i > 0:
-                in_dim = in_dim // 2
-            for _ in range(num_res_blocks + 1):
-                upsamples.append(WanResidualBlock(in_dim, out_dim, dropout))
-                if scale in attn_scales:
-                    upsamples.append(WanAttentionBlock(out_dim))
-                in_dim = out_dim
-
-            # upsample block
-            if i != len(dim_mult) - 1:
-                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
-                upsamples.append(WanResample(out_dim, mode=mode))
-                scale *= 2.0
-        self.upsamples = nn.Sequential(*upsamples)
 
-        # output blocks
-        self.head = nn.Sequential(
-            WanRMS_norm(out_dim, images=False),
-            self.nonlinearity,
-            WanCausalConv3d(out_dim, 3, 3, padding=1)
-        )
-
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
-        ## conv1
-        if feat_cache is not None:
-            idx = feat_idx[0]
-            cache_x = x[:, :, -CACHE_T:, :, :].clone()
-            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
-                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-            x = self.conv1(x, feat_cache[idx])
-            feat_cache[idx] = cache_x
-            feat_idx[0] += 1
-        else:
-            x = self.conv1(x)
-
-        x = self.middle(x, feat_cache, feat_idx)
-
-        ## upsamples
-        for layer in self.upsamples:
+        if self.upsamplers is not None:
             if feat_cache is not None:
-                x = layer(x, feat_cache, feat_idx)
-            else:
-                x = layer(x)
-
-        ## head
-        for layer in self.head:
-            if isinstance(layer, WanCausalConv3d) and feat_cache is not None:
-                idx = feat_idx[0]
-                cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                    # cache last frame of last two chunk
-                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-                x = layer(x, feat_cache[idx])
-                feat_cache[idx] = cache_x
-                feat_idx[0] += 1
+                x = self.upsamplers[0](x, feat_cache, feat_idx)
             else:
-                x = layer(x)
+                x = self.upsamplers[0](x)
         return x
 
 
-class WanDecoder3dYiYi(nn.Module):
+class WanDecoder3d(nn.Module):
     r"""
     A 3D decoder module.
 
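The hunk above deletes the hand-rolled `WanDecoder3d` and renames the `WanUpBlock`-based `WanDecoder3dYiYi` to take its place. The deleted code's channel schedule is still the one the renamed class has to reproduce, so it is worth working through once, using the defaults shown in the diff (`dim=128`, `dim_mult=[1, 2, 4, 4]`):

```python
dim, dim_mult = 128, [1, 2, 4, 4]

# Per-stage channel widths: the last multiplier is prepended, then the list
# is walked in reverse (decoder goes from narrow latent to wide pixels).
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
print(dims)   # [512, 512, 512, 256, 128]

# Spatial scale relative to full resolution; doubled after each upsample stage.
scale = 1.0 / 2 ** (len(dim_mult) - 2)
print(scale)  # 0.25
```

The `in_dim = in_dim // 2` adjustment for `i > 0` (visible in the next hunk) presumably compensates for `WanResample` halving the channel count on upsample, though that module's definition is outside this diff.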
@@ -809,8 +696,7 @@ def __init__(
             if i > 0:
                 in_dim = in_dim // 2
 
-            # Determine if we need attention and upsampling
-            use_attention = scale in attn_scales
+            # Determine if we need upsampling
             upsample_mode = None
             if i != len(dim_mult) - 1:
                 upsample_mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
@@ -821,7 +707,6 @@ def __init__(
                 out_dim=out_dim,
                 num_res_blocks=num_res_blocks,
                 dropout=dropout,
-                use_attention=use_attention,
                 upsample_mode=upsample_mode,
                 non_linearity=non_linearity,
             )
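Taken together, the diff swaps the decoder over to the modular `WanUpBlock` without changing numerics. A hedged smoke-test sketch; the latent shape and the way the cache is sized are assumptions on my part, not taken from this diff, and it assumes the module under edit is importable:

```python
import torch

decoder = WanDecoder3d()  # defaults shown in the diff: z_dim=4, dim=128

# Assumption: one cache slot per causal conv module in the decoder.
cache_size = sum(isinstance(m, WanCausalConv3d) for m in decoder.modules())

z = torch.randn(1, 4, 1, 32, 32)  # [b, z_dim, t, h, w], illustrative
with torch.no_grad():
    out = decoder(z, feat_cache=[None] * cache_size, feat_idx=[0])

# The deleted head conv mapped out_dim -> 3 channels, and three of the four
# stages upsample spatially, so expect 3 channels at 8x spatial resolution;
# temporal length depends on the 'upsample3d' stages.
print(out.shape)
```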