
Commit 3e0d128

Motion Model / Adapter versatility (#8301)
* Motion Model / Adapter versatility
  - allow a different number of layers per block
  - allow a different number of transformer layers per layer within each block
  - allow a different number of motion attention heads per block
  - use the dropout argument in get_down/up_block in the 3D blocks
* Motion Model: rename the added arguments & refactor
* Add a test for asymmetric UNetMotionModel
1 parent a536e77 commit 3e0d128
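
Each motion block touched by this commit normalizes its new int-or-tuple arguments the same way: a plain int is broadcast to one entry per layer, and a tuple must have exactly num_layers entries. The following standalone Python sketch summarizes that pattern; the helper name normalize_per_layer is hypothetical and not part of the commit, which inlines the same check in every block's __init__:

from typing import Tuple, Union

def normalize_per_layer(value: Union[int, Tuple[int, ...]], num_layers: int, name: str) -> Tuple[int, ...]:
    # An int applies to every layer in the block.
    if isinstance(value, int):
        return (value,) * num_layers
    # A tuple must provide one value per layer.
    if len(value) != num_layers:
        raise ValueError(f"`{name}` must be an integer or a tuple of integers of length {num_layers}")
    return tuple(value)

# Example: a block with two resnet layers, using 1 and 2 temporal transformer layers respectively.
print(normalize_per_layer(1, num_layers=2, name="temporal_transformer_layers_per_block"))       # (1, 1)
print(normalize_per_layer((1, 2), num_layers=2, name="temporal_transformer_layers_per_block"))  # (1, 2)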

File tree

3 files changed: +291 -38 lines changed


src/diffusers/models/unets/unet_3d_blocks.py

Lines changed: 112 additions & 15 deletions
@@ -58,7 +58,9 @@ def get_down_block(
     resnet_time_scale_shift: str = "default",
     temporal_num_attention_heads: int = 8,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    dropout: float = 0.0,
 ) -> Union[
     "DownBlock3D",
     "CrossAttnDownBlock3D",
@@ -79,6 +81,7 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            dropout=dropout,
         )
     elif down_block_type == "CrossAttnDownBlock3D":
         if cross_attention_dim is None:
@@ -100,6 +103,7 @@ def get_down_block(
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
+            dropout=dropout,
         )
     if down_block_type == "DownBlockMotion":
         return DownBlockMotion(
@@ -115,6 +119,8 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif down_block_type == "CrossAttnDownBlockMotion":
         if cross_attention_dim is None:
@@ -139,6 +145,8 @@ def get_down_block(
             resnet_time_scale_shift=resnet_time_scale_shift,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif down_block_type == "DownBlockSpatioTemporal":
         # added for SDV
@@ -189,7 +197,8 @@ def get_up_block(
     temporal_num_attention_heads: int = 8,
     temporal_cross_attention_dim: Optional[int] = None,
     temporal_max_seq_length: int = 32,
-    transformer_layers_per_block: int = 1,
+    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+    temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     dropout: float = 0.0,
 ) -> Union[
     "UpBlock3D",
@@ -212,6 +221,7 @@ def get_up_block(
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
+            dropout=dropout,
         )
     elif up_block_type == "CrossAttnUpBlock3D":
         if cross_attention_dim is None:
@@ -234,6 +244,7 @@ def get_up_block(
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
             resolution_idx=resolution_idx,
+            dropout=dropout,
         )
     if up_block_type == "UpBlockMotion":
         return UpBlockMotion(
@@ -250,6 +261,8 @@ def get_up_block(
             resolution_idx=resolution_idx,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif up_block_type == "CrossAttnUpBlockMotion":
         if cross_attention_dim is None:
@@ -275,6 +288,8 @@ def get_up_block(
             resolution_idx=resolution_idx,
             temporal_num_attention_heads=temporal_num_attention_heads,
             temporal_max_seq_length=temporal_max_seq_length,
+            temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
+            dropout=dropout,
         )
     elif up_block_type == "UpBlockSpatioTemporal":
         # added for SDV
@@ -948,14 +963,31 @@ def __init__(
         output_scale_factor: float = 1.0,
         add_downsample: bool = True,
         downsample_padding: int = 1,
-        temporal_num_attention_heads: int = 1,
+        temporal_num_attention_heads: Union[int, Tuple[int]] = 1,
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
         motion_modules = []
 
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`temporal_transformer_layers_per_block` must be an integer or a tuple of integers of length {num_layers}"
+            )
+
+        # support for variable number of attention head per temporal layers
+        if isinstance(temporal_num_attention_heads, int):
+            temporal_num_attention_heads = (temporal_num_attention_heads,) * num_layers
+        elif len(temporal_num_attention_heads) != num_layers:
+            raise ValueError(
+                f"`temporal_num_attention_heads` must be an integer or a tuple of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -974,15 +1006,16 @@ def __init__(
             )
             motion_modules.append(
                 TransformerTemporalModel(
-                    num_attention_heads=temporal_num_attention_heads,
+                    num_attention_heads=temporal_num_attention_heads[i],
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
                     activation_fn="geglu",
                     positional_embeddings="sinusoidal",
                     num_positional_embeddings=temporal_max_seq_length,
-                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                    attention_head_dim=out_channels // temporal_num_attention_heads[i],
                 )
             )
 
@@ -1065,7 +1098,7 @@ def __init__(
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1084,6 +1117,7 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
@@ -1093,6 +1127,22 @@ def __init__(
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             in_channels = in_channels if i == 0 else out_channels
             resnets.append(
@@ -1116,7 +1166,7 @@ def __init__(
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1141,6 +1191,7 @@ def __init__(
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1257,7 +1308,7 @@ def __init__(
         resolution_idx: Optional[int] = None,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1275,6 +1326,7 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
@@ -1284,6 +1336,22 @@ def __init__(
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(transformer_layers_per_block)}"
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}, got {len(temporal_transformer_layers_per_block)}"
+            )
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1309,7 +1377,7 @@ def __init__(
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1333,6 +1401,7 @@ def __init__(
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1467,11 +1536,20 @@ def __init__(
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
         resnets = []
         motion_modules = []
 
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"temporal_transformer_layers_per_block must be an integer or a list of integers of length {num_layers}"
+            )
+
         for i in range(num_layers):
             res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
             resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1495,6 +1573,7 @@ def __init__(
                 TransformerTemporalModel(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=temporal_norm_num_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
@@ -1596,7 +1675,7 @@ def __init__(
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
-        transformer_layers_per_block: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1605,20 +1684,37 @@ def __init__(
         num_attention_heads: int = 1,
         output_scale_factor: float = 1.0,
         cross_attention_dim: int = 1280,
-        dual_cross_attention: float = False,
-        use_linear_projection: float = False,
-        upcast_attention: float = False,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
         attention_type: str = "default",
         temporal_num_attention_heads: int = 1,
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_max_seq_length: int = 32,
+        temporal_transformer_layers_per_block: Union[int, Tuple[int]] = 1,
     ):
         super().__init__()
 
         self.has_cross_attention = True
         self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
 
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = (transformer_layers_per_block,) * num_layers
+        elif len(transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
+        # support for variable transformer layers per temporal block
+        if isinstance(temporal_transformer_layers_per_block, int):
+            temporal_transformer_layers_per_block = (temporal_transformer_layers_per_block,) * num_layers
+        elif len(temporal_transformer_layers_per_block) != num_layers:
+            raise ValueError(
+                f"`temporal_transformer_layers_per_block` should be an integer or a list of integers of length {num_layers}."
+            )
+
         # there is always at least one resnet
         resnets = [
             ResnetBlock2D(
@@ -1637,14 +1733,14 @@ def __init__(
         attentions = []
         motion_modules = []
 
-        for _ in range(num_layers):
+        for i in range(num_layers):
             if not dual_cross_attention:
                 attentions.append(
                     Transformer2DModel(
                         num_attention_heads,
                         in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=transformer_layers_per_block,
+                        num_layers=transformer_layers_per_block[i],
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1682,6 +1778,7 @@ def __init__(
                     num_attention_heads=temporal_num_attention_heads,
                     attention_head_dim=in_channels // temporal_num_attention_heads,
                     in_channels=in_channels,
+                    num_layers=temporal_transformer_layers_per_block[i],
                     norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
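
With the changes above, the motion blocks accept per-layer tuples for their temporal settings. A minimal usage sketch, assuming DownBlockMotion can be imported from this module and that in_channels, out_channels and temb_channels are the only other required constructor arguments; the channel sizes below are illustrative, not taken from the commit:

from diffusers.models.unets.unet_3d_blocks import DownBlockMotion

# Asymmetric motion down block: 2 layers, each with its own number of temporal
# attention heads and temporal transformer layers (both new in this commit).
# out_channels must remain divisible by every head count in the tuple.
block = DownBlockMotion(
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    num_layers=2,
    temporal_num_attention_heads=(4, 8),
    temporal_transformer_layers_per_block=(1, 2),
)
print(sum(p.numel() for p in block.parameters()))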

0 commit comments

Comments
 (0)