@@ -58,7 +58,9 @@ def get_down_block(
     resnet_time_scale_shift: str = "default",
     temporal_num_attention_heads: int = 8,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    dropout: float = 0.0,
 ) -> Union[
     "DownBlock3D",
     "CrossAttnDownBlock3D",
@@ -79,6 +81,7 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            dropout=dropout,
         )
     elif down_block_type == "CrossAttnDownBlock3D":
         if cross_attention_dim is None:
@@ -100,6 +103,7 @@ def get_down_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            dropout=dropout,
         )
     if down_block_type == "DownBlockMotion":
         return DownBlockMotion(
@@ -115,6 +119,8 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif down_block_type == "CrossAttnDownBlockMotion":
         if cross_attention_dim is None:
@@ -139,6 +145,8 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif down_block_type == "DownBlockSpatioTemporal":
         # added for SDV
@@ -189,7 +197,8 @@ def get_up_block(
     temporal_num_attention_heads: int = 8,
     temporal_cross_attention_dim: Optional[int] = None,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     dropout: float = 0.0,
 ) -> Union[
     "UpBlock3D",
@@ -212,6 +221,7 @@ def get_up_block(
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
+            dropout=dropout,
         )
     elif up_block_type == "CrossAttnUpBlock3D":
         if cross_attention_dim is None:
@@ -234,6 +244,7 @@ def get_up_block(
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
+            dropout=dropout,
         )
     if up_block_type == "UpBlockMotion":
         return UpBlockMotion(
@@ -250,6 +261,8 @@ def get_up_block(
             resolution_idx=resolution_idx,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif up_block_type == "CrossAttnUpBlockMotion":
         if cross_attention_dim is None:
@@ -275,6 +288,8 @@ def get_up_block(
             resolution_idx=resolution_idx,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif up_block_type == "UpBlockSpatioTemporal":
         # added for SDV
@@ -948,14 +963,31 @@ def __init__(
         output_scale_factor: float = 1.0,
         add_downsample: bool = True,
         downsample_padding: int = 1,
-        temporal_num_attention_heads: int = 1,
+        temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
         motion_modules = []
 
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
+            )
+
+        # support for variable number of attention head per temporal layers
+        if isinstance(temporal_num_attention_heads, int):
+            temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
+        elif len(temporal_num_attention_heads) != num_layers:
+            raise ValueError(
+                f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -974,15 +1006,16 @@ def __init__(
             )
             motion_modules.append(
                 TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
+                    num_attention_heads=temporal_num_attention_heads[i],
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
                     activation_fn="geglu",
                     positional_embeddings="sinusoidal",
                     num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                    attention_head_dim=out_channels // temporal_num_attention_heads[i],
                 )
             )
 
@@ -1065,7 +1098,7 @@ def __init__(
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1084,6 +1117,7 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
@@ -1093,6 +1127,22 @@ def __init__(
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -1116,7 +1166,7 @@ def __init__(
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1141,6 +1191,7 @@ def __init__(
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1257,7 +1308,7 @@ def __init__(
         resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1275,6 +1326,7 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
@@ -1284,6 +1336,22 @@ def __init__(
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
+            )
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1309,7 +1377,7 @@ def __init__(
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1333,6 +1401,7 @@ def __init__(
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1467,11 +1536,20 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
         motion_modules = []
 
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1495,6 +1573,7 @@ def __init__(
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=temporal_norm_num_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1596,7 +1675,7 @@ def __init__(
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1605,20 +1684,37 @@ def __init__(
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
         cross_attention_dim: int = 1280,
-        dual_cross_attention: float = False,
-        use_linear_projection: float = False,
-        upcast_attention: float = False,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
         attention_type: str = "default",
         temporal_num_attention_heads: int = 1,
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
 
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
@@ -1637,14 +1733,14 @@ def __init__(
         attentions = []
         motion_modules = []
 
-        for _ in range(num_layers):
+        for i in range(num_layers):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
                         in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1682,6 +1778,7 @@ def __init__(
                     num_attention_heads=temporal_num_attention_heads,
                     attention_head_dim=in_channels // temporal_num_attention_heads,
                     in_channels=in_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
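A minimal usage sketch of the int-or-tuple normalization pattern this diff repeats in each block's __init__. The helper name normalize_per_layer below is hypothetical and not part of diffusers; it only illustrates how a single int is broadcast to every layer while a tuple must match num_layers exactly, which is what the new temporal_transformer_layers_per_block handling does.

from typing import Tuple, Union


def normalize_per_layer(value: Union[int, Tuple[int, ...]], num_layers: int, name: str) -> Tuple[int, ...]:
    # Hypothetical helper: broadcast a scalar to all layers, or require an exact length match.
    if isinstance(value, int):
        return (value,) * num_layers
    if len(value) != num_layers:
        raise ValueError(f"`{name}` must be an integer or a tuple of integers of length {num_layers}")
    return tuple(value)


# Two resnet layers: one temporal transformer layer in the first, two in the second.
print(normalize_per_layer(1, num_layers=2, name="temporal_transformer_layers_per_block"))       # (1, 1)
print(normalize_per_layer((1, 2), num_layers=2, name="temporal_transformer_layers_per_block"))  # (1, 2)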