@@ -530,11 +530,10 @@ def __init__(
         self.middle = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
 
         # output blocks
-        self.head = nn.Sequential(
-            WanRMS_norm(out_dim, images=False),
-            self.nonlinearity,
-            WanCausalConv3d(out_dim, z_dim, 3, padding=1)
-        )
+        self.norm_out = WanRMS_norm(out_dim, images=False)
+        self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
+
+        self.gradient_checkpointing = False
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         if feat_cache is not None:
@@ -560,18 +559,19 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         x = self.middle(x, feat_cache, feat_idx)
 
         ## head
-        for layer in self.head:
-            if isinstance(layer, WanCausalConv3d) and feat_cache is not None:
-                idx = feat_idx[0]
-                cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+        x = self.norm_out(x)
+        x = self.nonlinearity(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 # cache last frame of last two chunk
-                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-                x = layer(x, feat_cache[idx])
-                feat_cache[idx] = cache_x
-                feat_idx[0] += 1
-            else:
-                x = layer(x)
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_out(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_out(x)
         return x
 
 
@@ -719,11 +719,10 @@ def __init__(
         self.upsamples = upsamples
 
         # output blocks
-        self.head = nn.Sequential(
-            WanRMS_norm(out_dim, images=False),
-            self.nonlinearity,
-            WanCausalConv3d(out_dim, 3, 3, padding=1)
-        )
+        self.norm_out = WanRMS_norm(out_dim, images=False)
+        self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
+
+        self.gradient_checkpointing = False
 
     def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## conv1
@@ -747,18 +746,19 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         x = up_block(x, feat_cache, feat_idx)
 
         ## head
-        for layer in self.head:
-            if isinstance(layer, WanCausalConv3d) and feat_cache is not None:
-                idx = feat_idx[0]
-                cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                    # cache last frame of last two chunk
-                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-                x = layer(x, feat_cache[idx])
-                feat_cache[idx] = cache_x
-                feat_idx[0] += 1
-            else:
-                x = layer(x)
+        x = self.norm_out(x)
+        x = self.nonlinearity(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_out(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_out(x)
         return x
 
 
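For orientation, here is a minimal sketch of the refactored output head in isolation: RMS normalization, the nonlinearity, then the final causal 3D convolution, mirroring the norm_out / nonlinearity / conv_out sequence above. SimpleRMSNorm, CausalConv3d, and OutputHead are simplified stand-ins written for this example (not the diffusers WanRMS_norm / WanCausalConv3d implementations), the nonlinearity is assumed to be SiLU, and the feat_cache path is omitted.

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleRMSNorm(nn.Module):
    # Stand-in for WanRMS_norm: RMS-normalize over the channel axis of a
    # (batch, channels, frames, height, width) tensor with a learnable scale.
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1))

    def forward(self, x):
        return F.normalize(x, dim=1) * self.scale * self.gamma


class CausalConv3d(nn.Conv3d):
    # Stand-in for WanCausalConv3d: pad only the "past" side of the temporal
    # axis so each output frame depends on current and earlier frames only.
    def __init__(self, in_dim, out_dim, kernel_size, padding=1):
        super().__init__(in_dim, out_dim, kernel_size, padding=0)
        # F.pad order for 5D input: (W_left, W_right, H_top, H_bottom, T_front, T_back)
        self._pad = (padding, padding, padding, padding, 2 * padding, 0)

    def forward(self, x):
        return super().forward(F.pad(x, self._pad))


class OutputHead(nn.Module):
    # Mirrors the refactored head: norm_out -> nonlinearity -> conv_out,
    # replacing the former nn.Sequential self.head.
    def __init__(self, out_dim, z_dim):
        super().__init__()
        self.norm_out = SimpleRMSNorm(out_dim)
        self.nonlinearity = nn.SiLU()
        self.conv_out = CausalConv3d(out_dim, z_dim, 3, padding=1)

    def forward(self, x):
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        return self.conv_out(x)


if __name__ == "__main__":
    head = OutputHead(out_dim=96, z_dim=16)
    video = torch.randn(1, 96, 4, 32, 32)  # (batch, channels, frames, height, width)
    print(head(video).shape)  # torch.Size([1, 16, 4, 32, 32])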