@@ -58,7 +58,9 @@ def get_down_block(
     resnet_time_scale_shift: str = "default",
     temporal_num_attention_heads: int = 8,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    dropout: float = 0.0,
 ) -> Union[
     "DownBlock3D",
     "CrossAttnDownBlock3D",
@@ -79,6 +81,7 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            dropout=dropout,
         )
     elif down_block_type == "CrossAttnDownBlock3D":
         if cross_attention_dim is None:
@@ -100,6 +103,7 @@ def get_down_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            dropout=dropout,
         )
     if down_block_type == "DownBlockMotion":
         return DownBlockMotion(
@@ -115,6 +119,8 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif down_block_type == "CrossAttnDownBlockMotion":
         if cross_attention_dim is None:
@@ -139,6 +145,8 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif down_block_type == "DownBlockSpatioTemporal":
         # added for SDV
@@ -189,7 +197,8 @@ def get_up_block(
     temporal_num_attention_heads: int = 8,
     temporal_cross_attention_dim: Optional[int] = None,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     dropout: float = 0.0,
 ) -> Union[
     "UpBlock3D",
@@ -212,6 +221,7 @@ def get_up_block(
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
+            dropout=dropout,
         )
     elif up_block_type == "CrossAttnUpBlock3D":
         if cross_attention_dim is None:
@@ -234,6 +244,7 @@ def get_up_block(
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
+            dropout=dropout,
         )
     if up_block_type == "UpBlockMotion":
         return UpBlockMotion(
@@ -250,6 +261,8 @@ def get_up_block(
             resolution_idx=resolution_idx,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif up_block_type == "CrossAttnUpBlockMotion":
         if cross_attention_dim is None:
@@ -275,6 +288,8 @@ def get_up_block(
             resolution_idx=resolution_idx,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif up_block_type == "UpBlockSpatioTemporal":
         # added for SDV
@@ -948,14 +963,31 @@ def __init__(
         output_scale_factor: float = 1.0,
         add_downsample: bool = True,
         downsample_padding: int = 1,
-        temporal_num_attention_heads: int = 1,
+        temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
         motion_modules = []
 
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
+            )
+
+        # support for variable number of attention head per temporal layers
+        if isinstance(temporal_num_attention_heads, int):
+            temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
+        elif len(temporal_num_attention_heads) != num_layers:
+            raise ValueError(
+                f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -974,15 +1006,16 @@ def __init__(
             )
             motion_modules.append(
                 TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
+                    num_attention_heads=temporal_num_attention_heads[i],
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
                     activation_fn="geglu",
                     positional_embeddings="sinusoidal",
                     num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                    attention_head_dim=out_channels // temporal_num_attention_heads[i],
                 )
             )
 
@@ -1065,7 +1098,7 @@ def __init__(
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1084,6 +1117,7 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
@@ -1093,6 +1127,22 @@ def __init__(
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -1116,7 +1166,7 @@ def __init__(
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1141,6 +1191,7 @@ def __init__(
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1257,7 +1308,7 @@ def __init__(
         resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1275,6 +1326,7 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
@@ -1284,6 +1336,22 @@ def __init__(
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
+            )
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1309,7 +1377,7 @@ def __init__(
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1333,6 +1401,7 @@ def __init__(
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1467,11 +1536,20 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
         motion_modules = []
 
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1495,6 +1573,7 @@ def __init__(
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=temporal_norm_num_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1596,7 +1675,7 @@ def __init__(
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1605,20 +1684,37 @@ def __init__(
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
         cross_attention_dim: int = 1280,
-        dual_cross_attention: float = False,
-        use_linear_projection: float = False,
-        upcast_attention: float = False,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
         attention_type: str = "default",
         temporal_num_attention_heads: int = 1,
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
 
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
@@ -1637,14 +1733,14 @@ def __init__(
         attentions = []
         motion_modules = []
 
-        for _ in range(num_layers):
+        for i in range(num_layers):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
                         in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1682,6 +1778,7 @@ def __init__(
                     num_attention_heads=temporal_num_attention_heads,
                     attention_head_dim=in_channels // temporal_num_attention_heads,
                     in_channels=in_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
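Taken together, these hunks apply one recurring pattern: each motion block accepts transformer_layers_per_block / temporal_transformer_layers_per_block (and, for DownBlockMotion, temporal_num_attention_heads) either as a single int or as a per-layer tuple, broadcasts the int case to num_layers entries, validates the tuple length, and then indexes the result with [i] inside the for i in range(num_layers) loop (which is also why the mid-block loop changes "for _" to "for i"). Below is a minimal standalone sketch of that normalization pattern under those assumptions; the helper name normalize_per_layer is illustrative and not part of the diffusers API.

# Minimal standalone sketch (assumed helper, not the diffusers API): the
# int-or-tuple broadcasting and validation that the diff repeats in each
# motion block's __init__.
from typing import Tuple, Union


def normalize_per_layer(value: Union[int, Tuple[int, ...]], num_layers: int, name: str) -> Tuple[int, ...]:
    # A bare int is broadcast to one entry per resnet/transformer layer.
    if isinstance(value, int):
        return (value,) * num_layers
    # A tuple must already have exactly one entry per layer.
    if len(value) != num_layers:
        raise ValueError(f"`{name}` must be an integer or a tuple of integers of length {num_layers}")
    return tuple(value)


if __name__ == "__main__":
    num_layers = 3
    # Broadcast case: the old single-int behaviour is preserved.
    print(normalize_per_layer(1, num_layers, "transformer_layers_per_block"))  # (1, 1, 1)
    # Per-layer case: e.g. give later layers more temporal transformer depth.
    print(normalize_per_layer((1, 2, 4), num_layers, "temporal_transformer_layers_per_block"))  # (1, 2, 4)

Each normalized tuple is then consumed per iteration, e.g. num_layers=temporal_transformer_layers_per_block[i] when constructing TransformerTemporalModel, so passing a plain int keeps the previous behaviour while a tuple enables per-layer configuration.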