@@ -86,90 +86,6 @@ def forward(self, x, cache_x=None):
         return super().forward(x)
 
 
-# TODO: not used yet, will not affect the state dict so can be refactored in follow up PR
-class WanCausalConv3dYiYi(nn.Conv3d):
-    r"""
-    A custom 3D causal convolution layer with feature caching support.
-
-    This layer extends the standard Conv3D layer by ensuring causality in the time dimension
-    and handling feature caching for efficient inference.
-
-    Args:
-        in_channels (int): Number of channels in the input image
-        out_channels (int): Number of channels produced by the convolution
-        kernel_size (int or tuple): Size of the convolving kernel
-        stride (int or tuple, optional): Stride of the convolution. Default: 1
-        padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
-    """
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: Union[int, Tuple[int, int, int]],
-        stride: Union[int, Tuple[int, int, int]] = 1,
-        padding: Union[int, Tuple[int, int, int]] = 0,
-    ) -> None:
-        super().__init__(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-        )
-
-        # Set up causal padding
-        self._padding = (
-            self.padding[2],
-            self.padding[2],
-            self.padding[1],
-            self.padding[1],
-            2 * self.padding[0],
-            0
-        )
-        self.padding = (0, 0, 0)
-
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
-        """
-        Forward pass with feature caching support.
-
-        Args:
-            x (torch.Tensor): Input tensor
-            feat_cache (list, optional): List to store cached features
-            feat_idx (list, optional): List with a single integer indicating the current cache index
-
-        Returns:
-            torch.Tensor: Output tensor after convolution
-        """
-        # Handle feature caching
-        if feat_cache is not None:
-            idx = feat_idx[0]
-            cache_x = x[:, :, -CACHE_T:, :, :].clone()
-
-            # Concatenate with cached frame if available
-            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                # cache last frame of last two chunk
-                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-
-            # Apply padding and convolution with cached data
-            padding = list(self._padding)
-            if feat_cache[idx] is not None and self._padding[4] > 0:
-                x = torch.cat([feat_cache[idx], x], dim=2)
-                padding[4] -= feat_cache[idx].shape[2]
-
-            x = F.pad(x, padding)
-            result = super().forward(x)
-
-            # Update cache
-            feat_cache[idx] = cache_x
-            feat_idx[0] += 1
-
-            return result
-        else:
-            # Standard forward pass without caching
-            x = F.pad(x, self._padding)
-            return super().forward(x)
-
-
 class WanRMS_norm(nn.Module):
     r"""
     A custom RMS normalization layer.
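For context, the class removed above duplicates the causal-convolution behaviour already provided by the WanCausalConv3d used throughout this file: all temporal padding is applied on the past side of the clip, so an output frame never depends on future frames. A minimal, self-contained sketch of that idea (the class name, channel counts, and shapes below are illustrative only, not part of this diff):

```python
# Illustrative sketch only: a causal Conv3d pads the time axis entirely on the
# left (the past), so output frame t depends only on input frames <= t.
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv3dSketch(nn.Conv3d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        # F.pad takes (W_left, W_right, H_left, H_right, T_left, T_right):
        # keep spatial padding symmetric, put all temporal padding before the clip.
        self._causal_padding = (
            self.padding[2], self.padding[2],   # width
            self.padding[1], self.padding[1],   # height
            2 * self.padding[0], 0,             # time: past only
        )
        self.padding = (0, 0, 0)                # disable Conv3d's own symmetric padding

    def forward(self, x):
        return super().forward(F.pad(x, self._causal_padding))


# Usage: a 5-frame clip keeps its length with kernel_size=3, padding=1.
conv = CausalConv3dSketch(3, 8, kernel_size=3, padding=1)
print(conv(torch.randn(1, 3, 5, 32, 32)).shape)  # torch.Size([1, 8, 5, 32, 32])
```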
@@ -501,26 +417,26 @@ def __init__(
         scale = 1.0
 
         # init block
-        self.conv1 = WanCausalConv3d(3, dims[0], 3, padding=1)
+        self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
 
         # downsample blocks
-        self.downsamples = nn.ModuleList([])
+        self.down_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
             # residual (+attention) blocks
             for _ in range(num_res_blocks):
-                self.downsamples.append(WanResidualBlock(in_dim, out_dim, dropout))
+                self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
                 if scale in attn_scales:
-                    self.downsamples.append(WanAttentionBlock(out_dim))
+                    self.down_blocks.append(WanAttentionBlock(out_dim))
                 in_dim = out_dim
 
             # downsample block
             if i != len(dim_mult) - 1:
                 mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
-                self.downsamples.append(WanResample(out_dim, mode=mode))
+                self.down_blocks.append(WanResample(out_dim, mode=mode))
                 scale /= 2.0
 
         # middle blocks
-        self.middle = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
+        self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
 
         # output blocks
         self.norm_out = WanRMS_norm(out_dim, images=False)
@@ -535,21 +451,21 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 # cache last frame of last two chunk
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-            x = self.conv1(x, feat_cache[idx])
+            x = self.conv_in(x, feat_cache[idx])
             feat_cache[idx] = cache_x
             feat_idx[0] += 1
         else:
-            x = self.conv1(x)
+            x = self.conv_in(x)
 
         ## downsamples
-        for layer in self.downsamples:
+        for layer in self.down_blocks:
             if feat_cache is not None:
                 x = layer(x, feat_cache, feat_idx)
             else:
                 x = layer(x)
 
         ## middle
-        x = self.middle(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache, feat_idx)
 
         ## head
         x = self.norm_out(x)
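The forward paths in this file thread a shared feat_cache list and a one-element feat_idx counter through every causal convolution, so consecutive temporal chunks can reuse the trailing frames of the previous chunk. A toy sketch of that calling convention (the layer class, tensor sizes, and chunking here are made up for illustration; only the cache/index pattern mirrors the code above):

```python
# Toy illustration of the feat_cache / feat_idx convention used above.
import torch

CACHE_T = 2  # trailing frames each layer keeps for the next chunk (mirrors the constant used above)


class ToyCachingLayer:
    def __call__(self, x, feat_cache=None, feat_idx=[0]):
        if feat_cache is not None:
            idx = feat_idx[0]
            # remember the last CACHE_T frames of this chunk for the next call
            feat_cache[idx] = x[:, :, -CACHE_T:].clone()
            feat_idx[0] += 1  # advance to the next layer's cache slot
        return x


layers = [ToyCachingLayer(), ToyCachingLayer()]
feat_cache = [None] * len(layers)          # one persistent slot per caching layer

video = torch.randn(1, 3, 8, 4, 4)         # (batch, channels, frames, H, W)
for chunk in video.split(4, dim=2):        # process two 4-frame chunks in sequence
    feat_idx = [0]                         # reset the slot counter per chunk
    for layer in layers:
        chunk = layer(chunk, feat_cache, feat_idx)
```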
@@ -676,14 +592,14 @@ def __init__(
         scale = 1.0 / 2 ** (len(dim_mult) - 2)
 
         # init block
-        self.conv1 = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
+        self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
 
         # middle blocks
-        self.middle = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
+        self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
 
 
         # upsample blocks
-        upsamples = nn.ModuleList([])
+        self.up_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
             # residual (+attention) blocks
             if i > 0:
@@ -703,14 +619,12 @@ def __init__(
                 upsample_mode=upsample_mode,
                 non_linearity=non_linearity,
             )
-            upsamples.append(up_block)
+            self.up_blocks.append(up_block)
 
             # Update scale for next iteration
             if upsample_mode is not None:
                 scale *= 2.0
 
-        self.upsamples = upsamples
-
         # output blocks
         self.norm_out = WanRMS_norm(out_dim, images=False)
         self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
@@ -725,17 +639,17 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
             if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 # cache last frame of last two chunk
                 cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-            x = self.conv1(x, feat_cache[idx])
+            x = self.conv_in(x, feat_cache[idx])
             feat_cache[idx] = cache_x
             feat_idx[0] += 1
         else:
-            x = self.conv1(x)
+            x = self.conv_in(x)
 
         ## middle
-        x = self.middle(x, feat_cache, feat_idx)
+        x = self.mid_block(x, feat_cache, feat_idx)
 
         ## upsamples
-        for up_block in self.upsamples:
+        for up_block in self.up_blocks:
             x = up_block(x, feat_cache, feat_idx)
 
         ## head
@@ -796,8 +710,8 @@ def __init__(
             base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales,
             self.temperal_downsample, dropout
         )
-        self.conv1 = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
-        self.conv2 = WanCausalConv3d(z_dim, z_dim, 1)
+        self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
+        self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
 
         self.decoder = WanDecoder3d(
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales,
@@ -834,7 +748,7 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
                 out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
                 out = torch.cat([out, out_], 2)
 
-        enc = self.conv1(out)
+        enc = self.quant_conv(out)
         mu, logvar = enc[:, :self.z_dim, :, :, :], enc[:, self.z_dim:, :, :, :]
         mu = (mu - self.scale[0].view(1, self.z_dim, 1, 1, 1)) * self.scale[1].view(1, self.z_dim, 1, 1, 1)
         logvar = (logvar - self.scale[0].view(1, self.z_dim, 1, 1, 1)) * self.scale[1].view(1, self.z_dim, 1, 1, 1)
@@ -870,7 +784,7 @@ def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[Dec
         z = z / self.scale[1].view(1, self.z_dim, 1, 1, 1) + self.scale[0].view(1, self.z_dim, 1, 1, 1)
 
         iter_ = z.shape[2]
-        x = self.conv2(z)
+        x = self.post_quant_conv(z)
         for i in range(iter_):
             self._conv_idx = [0]
             if i == 0:
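Because these hunks rename registered submodules (conv1 to conv_in or quant_conv, downsamples to down_blocks, middle to mid_block, upsamples to up_blocks, conv2 to post_quant_conv), a checkpoint saved under the old attribute names would need its state-dict keys remapped before loading into the renamed model. A hypothetical remapping helper is sketched below; the function name and the prefix table are illustrative only and are not part of this PR:

```python
# Hypothetical helper (not from this PR): remap old-style checkpoint keys to the
# attribute names introduced by the renames above.
RENAME_MAP = {
    "encoder.conv1.": "encoder.conv_in.",
    "encoder.downsamples.": "encoder.down_blocks.",
    "encoder.middle.": "encoder.mid_block.",
    "decoder.conv1.": "decoder.conv_in.",
    "decoder.middle.": "decoder.mid_block.",
    "decoder.upsamples.": "decoder.up_blocks.",
    "conv1.": "quant_conv.",
    "conv2.": "post_quant_conv.",
}


def remap_wan_vae_state_dict(state_dict):
    """Return a copy of `state_dict` with old attribute prefixes renamed."""
    remapped = {}
    for key, value in state_dict.items():
        for old, new in RENAME_MAP.items():   # longer, more specific prefixes are listed first
            if key.startswith(old):
                key = new + key[len(old):]
                break
        remapped[key] = value
    return remapped
```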