 from ..upsampling import Upsample2D
 
 
-class AllegroTemporalConvBlock(nn.Module):
+class AllegroTemporalConvLayer(nn.Module):
     r"""
     Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from:
     https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
@@ -40,6 +40,7 @@ def __init__(
         in_dim: int,
         out_dim: Optional[int] = None,
         dropout: float = 0.0,
+        norm_num_groups: int = 32,
         up_sample: bool = False,
         down_sample: bool = False,
         stride: int = 1,
@@ -55,44 +56,40 @@ def __init__(
 
         if down_sample:
             self.conv1 = nn.Sequential(
-                nn.GroupNorm(32, in_dim),
+                nn.GroupNorm(norm_num_groups, in_dim),
                 nn.SiLU(),
                 nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)),
             )
         elif up_sample:
             self.conv1 = nn.Sequential(
-                nn.GroupNorm(32, in_dim),
+                nn.GroupNorm(norm_num_groups, in_dim),
                 nn.SiLU(),
                 nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)),
             )
         else:
             self.conv1 = nn.Sequential(
-                nn.GroupNorm(32, in_dim),
+                nn.GroupNorm(norm_num_groups, in_dim),
                 nn.SiLU(),
                 nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
             )
         self.conv2 = nn.Sequential(
-            nn.GroupNorm(32, out_dim),
+            nn.GroupNorm(norm_num_groups, out_dim),
             nn.SiLU(),
             nn.Dropout(dropout),
             nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)),
         )
         self.conv3 = nn.Sequential(
-            nn.GroupNorm(32, out_dim),
+            nn.GroupNorm(norm_num_groups, out_dim),
             nn.SiLU(),
             nn.Dropout(dropout),
             nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
         )
         self.conv4 = nn.Sequential(
-            nn.GroupNorm(32, out_dim),
+            nn.GroupNorm(norm_num_groups, out_dim),
             nn.SiLU(),
             nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)),
         )
 
-        # zero out the last layer params, so the conv block is identity
-        nn.init.zeros_(self.conv4[-1].weight)
-        nn.init.zeros_(self.conv4[-1].bias)
-
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         identity = hidden_states
 
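For reference, a minimal usage sketch of the renamed layer with the new norm_num_groups argument; the sizes and the (batch, channels, frames, height, width) input layout are assumptions inferred from the nn.Conv3d stacks above, not taken from this diff.

    import torch

    # Hypothetical values: in_dim must be divisible by norm_num_groups for nn.GroupNorm.
    layer = AllegroTemporalConvLayer(
        in_dim=128,
        out_dim=128,
        dropout=0.1,
        norm_num_groups=32,  # previously hard-coded as nn.GroupNorm(32, ...)
    )

    # Assumed 5D layout for the Conv3d stacks: (batch, channels, frames, height, width).
    video = torch.randn(1, 128, 8, 32, 32)
    out = layer(video)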
@@ -169,19 +166,20 @@ def __init__(
                 )
             )
             temp_convs.append(
-                AllegroTemporalConvBlock(
+                AllegroTemporalConvLayer(
                     out_channels,
                     out_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )
 
         self.resnets = nn.ModuleList(resnets)
         self.temp_convs = nn.ModuleList(temp_convs)
 
         if add_temp_downsample:
-            self.temp_convs_down = AllegroTemporalConvBlock(
-                out_channels, out_channels, dropout=0.1, down_sample=True, stride=3
+            self.temp_convs_down = AllegroTemporalConvLayer(
+                out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3
             )
         self.add_temp_downsample = add_temp_downsample
 
@@ -258,10 +256,11 @@ def __init__(
                 )
             )
             temp_convs.append(
-                AllegroTemporalConvBlock(
+                AllegroTemporalConvLayer(
                     out_channels,
                     out_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )
 
@@ -270,8 +269,8 @@ def __init__(
 
         self.add_temp_upsample = add_temp_upsample
         if add_temp_upsample:
-            self.temp_conv_up = AllegroTemporalConvBlock(
-                out_channels, out_channels, dropout=0.1, up_sample=True, stride=3
+            self.temp_conv_up = AllegroTemporalConvLayer(
+                out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3
             )
 
         if self.add_upsample:
@@ -336,10 +335,11 @@ def __init__(
             )
         ]
         temp_convs = [
-            AllegroTemporalConvBlock(
+            AllegroTemporalConvLayer(
                 in_channels,
                 in_channels,
                 dropout=0.1,
+                norm_num_groups=resnet_groups,
             )
         ]
         attentions = []
@@ -383,10 +383,11 @@ def __init__(
             )
 
             temp_convs.append(
-                AllegroTemporalConvBlock(
+                AllegroTemporalConvLayer(
                     in_channels,
                     in_channels,
                     dropout=0.1,
+                    norm_num_groups=resnet_groups,
                 )
             )
@@ -513,6 +514,7 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         sample = sample + residual
 
         if self.gradient_checkpointing:
+
             def create_custom_forward(module):
                 def custom_forward(*inputs):
                     return module(*inputs)
@@ -655,24 +657,19 @@ def forward(self, sample: torch.Tensor) -> torch.Tensor:
         upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
 
         if self.gradient_checkpointing:
+
             def create_custom_forward(module):
                 def custom_forward(*inputs):
                     return module(*inputs)
 
                 return custom_forward
 
             # Mid block
-            sample = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self.mid_block),
-                sample
-            )
+            sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
 
             # Up blocks
             for up_block in self.up_blocks:
-                sample = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(up_block),
-                    sample
-                )
+                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
 
         else:
             # Mid block
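The collapsed checkpoint calls above follow the usual torch.utils.checkpoint wrapper pattern; a standalone sketch with an illustrative stand-in module (the block here is not part of this model):

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    # Stand-in for mid_block / up_block from the diff.
    block = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.SiLU())

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    sample = torch.randn(1, 3, 16, 16, requires_grad=True)
    # Activations inside `block` are recomputed during backward instead of being stored.
    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(block), sample)
    sample.sum().backward()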