55import math
66from math import sqrt , ceil
77from functools import partial
8+ from typing import Optional , Union , Tuple
89
910import torch
1011from torch import nn , einsum
@@ -207,12 +208,15 @@ def __init__(
207208 attn_dim_head = 64 ,
208209 attn_res_mp_add_t = 0.3 ,
209210 attn_flash = False ,
210- downsample = False
211+ downsample = False ,
212+ downsample_config : Tuple [bool , bool , bool ] = (True , True , True )
211213 ):
212214 super ().__init__ ()
213215 dim_out = default (dim_out , dim )
214216
215217 self .downsample = downsample
218+ self .downsample_config = downsample_config
219+
216220 self .downsample_conv = None
217221
218222 curr_dim = dim
@@ -259,7 +263,10 @@ def forward(
259263 ):
260264 if self .downsample :
261265 t , h , w = x .shape [- 3 :]
262- x = F .interpolate (x , (t // 2 , h // 2 , w // 2 ), mode = 'trilinear' )
266+ resize_factors = tuple ((2 if downsample else 1 ) for downsample in self .downsample_config )
267+ interpolate_shape = tuple (shape // factor for shape , factor in zip ((t , h , w ), resize_factors ))
268+
269+ x = F .interpolate (x , interpolate_shape , mode = 'trilinear' )
263270 x = self .downsample_conv (x )
264271
265272 x = self .pixel_norm (x )
@@ -294,12 +301,15 @@ def __init__(
294301 attn_dim_head = 64 ,
295302 attn_res_mp_add_t = 0.3 ,
296303 attn_flash = False ,
297- upsample = False
304+ upsample = False ,
305+ upsample_config : Tuple [bool , bool , bool ] = (True , True , True )
298306 ):
299307 super ().__init__ ()
300308 dim_out = default (dim_out , dim )
301309
302310 self .upsample = upsample
311+ self .upsample_config = upsample_config
312+
303313 self .needs_skip = not upsample
304314
305315 self .to_emb = None
@@ -341,7 +351,10 @@ def forward(
341351 ):
342352 if self .upsample :
343353 t , h , w = x .shape [- 3 :]
344- x = F .interpolate (x , (t * 2 , h * 2 , w * 2 ), mode = 'trilinear' )
354+ resize_factors = tuple ((2 if upsample else 1 ) for upsample in self .upsample_config )
355+ interpolate_shape = tuple (shape * factor for shape , factor in zip ((t , h , w ), resize_factors ))
356+
357+ x = F .interpolate (x , interpolate_shape , mode = 'trilinear' )
345358
346359 res = self .res_conv (x )
347360
@@ -416,12 +429,14 @@ def __init__(
416429 self ,
417430 * ,
418431 image_size ,
432+ frames ,
419433 dim = 192 ,
420434 dim_max = 768 , # channels will double every downsample and cap out to this value
421435 num_classes = None , # in paper, they do 1000 classes for a popular benchmark
422436 channels = 4 , # 4 channels in paper for some reason, must be alpha channel?
423437 num_downsamples = 3 ,
424- num_blocks_per_stage = 4 ,
438+ num_blocks_per_stage : Union [int , Tuple [int , ...]] = 4 ,
439+ downsample_types : Optional [Tuple [str , ...]] = None ,
425440 attn_res = (16 , 8 ),
426441 fourier_dim = 16 ,
427442 attn_dim_head = 64 ,
@@ -440,7 +455,9 @@ def __init__(
440455 # determine dimensions
441456
442457 self .channels = channels
458+ self .frames = frames
443459 self .image_size = image_size
460+
444461 input_channels = channels * (2 if self_condition else 1 )
445462
446463 # input and output blocks
@@ -478,6 +495,25 @@ def __init__(
478495
479496 self .num_downsamples = num_downsamples
480497
498+ # specifying downsample types (either image, frames, or both)
499+
500+ downsample_types = default (downsample_types , 'all' )
501+ downsample_types = cast_tuple (downsample_types , num_downsamples )
502+
503+ assert len (downsample_types ) == num_downsamples
504+ assert all ([t in {'all' , 'frame' , 'image' } for t in downsample_types ])
505+
506+ # number of blocks per downsample
507+
508+ num_blocks_per_stage = cast_tuple (num_blocks_per_stage , num_downsamples )
509+
510+ if len (num_blocks_per_stage ) == num_downsamples :
511+ first , * _ = num_blocks_per_stage
512+ num_blocks_per_stage = (first , * num_blocks_per_stage )
513+
514+ assert len (num_blocks_per_stage ) == (num_downsamples + 1 )
515+ assert all ([num_blocks >= 1 for num_blocks in num_blocks_per_stage ])
516+
481517 # attention
482518
483519 attn_res = set (cast_tuple (attn_res ))
@@ -498,17 +534,18 @@ def __init__(
498534 self .ups = ModuleList ([])
499535
500536 curr_dim = dim
501- curr_res = image_size
537+ curr_image_res = image_size
538+ curr_frame_res = frames
502539
503540 self .skip_mp_cat = MPCat (t = mp_cat_t , dim = 1 )
504541
505542 # take care of skip connection for initial input block and first three encoder blocks
506543
507544 prepend (self .ups , Decoder (dim * 2 , dim , ** block_kwargs ))
508545
509- assert num_blocks_per_stage >= 1
546+ init_num_blocks_per_stage , * rest_num_blocks_per_stage = num_blocks_per_stage
510547
511- for _ in range (num_blocks_per_stage ):
548+ for _ in range (init_num_blocks_per_stage ):
512549 enc = Encoder (curr_dim , curr_dim , ** block_kwargs )
513550 dec = Decoder (curr_dim * 2 , curr_dim , ** block_kwargs )
514551
@@ -517,20 +554,53 @@ def __init__(
517554
518555 # stages
519556
520- for _ in range (self .num_downsamples ):
557+ for _ , layer_num_blocks_per_stage , layer_downsample_type in zip (range (self .num_downsamples ), rest_num_blocks_per_stage , downsample_types ):
558+
521559 dim_out = min (dim_max , curr_dim * 2 )
522- upsample = Decoder (dim_out , curr_dim , has_attn = curr_res in attn_res , upsample = True , ** block_kwargs )
523560
524- curr_res //= 2
525- has_attn = curr_res in attn_res
561+ downsample_image = layer_downsample_type in { 'all' , 'image' }
562+ downsample_frame = layer_downsample_type in { 'all' , 'frame' }
526563
527- downsample = Encoder (curr_dim , dim_out , downsample = True , has_attn = has_attn , ** block_kwargs )
564+ assert not (downsample_image and not divisible_by (curr_image_res , 2 ))
565+ assert not (downsample_frame and not divisible_by (curr_frame_res , 2 ))
566+
567+ down_and_upsample_config = (
568+ downsample_frame ,
569+ downsample_image ,
570+ downsample_image
571+ )
572+
573+ upsample = Decoder (
574+ dim_out ,
575+ curr_dim ,
576+ has_attn = curr_image_res in attn_res ,
577+ upsample = True ,
578+ upsample_config = down_and_upsample_config ,
579+ ** block_kwargs
580+ )
581+
582+ if downsample_image :
583+ curr_image_res //= 2
584+
585+ if downsample_frame :
586+ curr_frame_res //= 2
587+
588+ has_attn = curr_image_res in attn_res
589+
590+ downsample = Encoder (
591+ curr_dim ,
592+ dim_out ,
593+ downsample = True ,
594+ downsample_config = down_and_upsample_config ,
595+ has_attn = has_attn ,
596+ ** block_kwargs
597+ )
528598
529599 append (self .downs , downsample )
530600 prepend (self .ups , upsample )
531601 prepend (self .ups , Decoder (dim_out * 2 , dim_out , has_attn = has_attn , ** block_kwargs ))
532602
533- for _ in range (num_blocks_per_stage ):
603+ for _ in range (layer_num_blocks_per_stage ):
534604 enc = Encoder (dim_out , dim_out , has_attn = has_attn , ** block_kwargs )
535605 dec = Decoder (dim_out * 2 , dim_out , has_attn = has_attn , ** block_kwargs )
536606
@@ -541,7 +611,7 @@ def __init__(
541611
542612 # take care of the two middle decoders
543613
544- mid_has_attn = curr_res in attn_res
614+ mid_has_attn = curr_image_res in attn_res
545615
546616 self .mids = ModuleList ([
547617 Decoder (curr_dim , curr_dim , has_attn = mid_has_attn , ** block_kwargs ),
@@ -563,7 +633,7 @@ def forward(
563633 ):
564634 # validate image shape
565635
566- assert x .shape [1 :] == (self .channels , self .image_size , self .image_size , self .image_size )
636+ assert x .shape [1 :] == (self .channels , self .frames , self .image_size , self .image_size )
567637
568638 # self conditioning
569639
@@ -689,19 +759,30 @@ def forward(self, x):
689759# example
690760
if __name__ == '__main__':

    # Small smoke-test configuration: a tiny base width (dim = 8) and small
    # attention heads keep the example cheap to run.
    unet = KarrasUnet3D(
        frames = 32,
        image_size = 64,
        dim = 8,
        dim_max = 768,
        num_downsamples = 6,
        # One entry per downsample stage. Because the length equals
        # num_downsamples, the constructor reuses the first entry for the
        # initial (pre-downsample) stage as well.
        num_blocks_per_stage = (4, 3, 2, 2, 2, 2),
        # Alternate spatial and temporal downsampling. Each stage halves only
        # the selected axes, so the image side goes 64 -> 8 and the frame
        # count goes 32 -> 4 over the six stages.
        downsample_types = (
            'image',
            'frame',
            'image',
            'frame',
            'image',
            'frame',
        ),
        attn_dim_head = 8,
        num_classes = 1000,
    )

    # Input layout is (batch, channels, frames, height, width); forward()
    # asserts that the trailing dims equal (channels, frames, image_size, image_size).
    images = torch.randn(2, 4, 32, 64, 64)

    denoised_images = unet(
        images,
        time = torch.ones(2,),
        class_labels = torch.randint(0, 1000, (2,))
    )
0 commit comments