
Commit cecdd8b

Adapt UNet2D for super-resolution (#1385)
* allow disabling self attention
* add class_embedding
* fix copies
* fix condition
* fix copies
* do_self_attention -> only_cross_attention
* fix copies
* num_classes -> num_class_embeds
* fix default value
1 parent 30f6f44 commit cecdd8b
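
The commit adds two new UNet2DConditionModel config options: only_cross_attention, a per-block switch that turns each transformer block's first attention layer into cross-attention over the text context, and num_class_embeds, which adds an nn.Embedding whose output is summed into the timestep embedding. Below is a minimal configuration sketch of how the two might be combined for a super-resolution UNet; every concrete value (channel counts, block types, the embedding count) is an illustrative assumption, not taken from this commit.

# Illustrative sketch only: the option names come from this commit, the values are assumptions.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=64,
    in_channels=7,   # e.g. noisy latents concatenated with a low-res conditioning image (assumption)
    out_channels=4,
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    # New: one flag per down block; flagged blocks use cross-attention in place of self-attention.
    only_cross_attention=(True, True, True, False),
    # New: enables the class embedding, usable e.g. for a discretised noise level (assumption).
    num_class_embeds=1000,
)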

File tree

src/diffusers/models/attention.py
src/diffusers/models/unet_2d_blocks.py
src/diffusers/models/unet_2d_condition.py
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

4 files changed: 60 additions, 1 deletion

src/diffusers/models/attention.py

Lines changed: 10 additions & 1 deletion
@@ -100,6 +100,7 @@ def __init__(
         activation_fn: str = "geglu",
         num_embeds_ada_norm: Optional[int] = None,
         use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -157,6 +158,7 @@ def __init__(
                     activation_fn=activation_fn,
                     num_embeds_ada_norm=num_embeds_ada_norm,
                     attention_bias=attention_bias,
+                    only_cross_attention=only_cross_attention,
                 )
                 for d in range(num_layers)
             ]
@@ -387,14 +389,17 @@ def __init__(
         activation_fn: str = "geglu",
         num_embeds_ada_norm: Optional[int] = None,
         attention_bias: bool = False,
+        only_cross_attention: bool = False,
     ):
         super().__init__()
+        self.only_cross_attention = only_cross_attention
         self.attn1 = CrossAttention(
             query_dim=dim,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
             dropout=dropout,
             bias=attention_bias,
+            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
         )  # is a self-attention
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
         self.attn2 = CrossAttention(
@@ -461,7 +466,11 @@ def forward(self, hidden_states, context=None, timestep=None):
         norm_hidden_states = (
             self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
         )
-        hidden_states = self.attn1(norm_hidden_states) + hidden_states
+
+        if self.only_cross_attention:
+            hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
+        else:
+            hidden_states = self.attn1(norm_hidden_states) + hidden_states
 
         # 2. Cross-Attention
         norm_hidden_states = (

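The last hunk above is the core of only_cross_attention: when the flag is set, attn1 is built with a cross_attention_dim and its forward call receives the context, so the block's first attention layer attends to the encoder hidden states instead of to itself. The following is a self-contained toy sketch of that pattern, using a stand-in attention module rather than diffusers' CrossAttention.

# Toy illustration of the only_cross_attention switch; not the diffusers implementation.
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    """Stand-in for CrossAttention: context_dim=None means plain self-attention."""
    def __init__(self, query_dim, context_dim=None, heads=4):
        super().__init__()
        kv_dim = context_dim if context_dim is not None else query_dim
        self.attn = nn.MultiheadAttention(query_dim, heads, kdim=kv_dim, vdim=kv_dim, batch_first=True)

    def forward(self, x, context=None):
        context = context if context is not None else x
        out, _ = self.attn(x, context, context)
        return out

class ToyTransformerBlock(nn.Module):
    def __init__(self, dim, context_dim, only_cross_attention=False):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.norm1 = nn.LayerNorm(dim)
        # Mirrors `cross_attention_dim=cross_attention_dim if only_cross_attention else None`.
        self.attn1 = ToyAttention(dim, context_dim if only_cross_attention else None)

    def forward(self, hidden_states, context=None):
        norm_hidden_states = self.norm1(hidden_states)
        if self.only_cross_attention:
            hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
        else:
            hidden_states = self.attn1(norm_hidden_states) + hidden_states
        return hidden_states

block = ToyTransformerBlock(dim=32, context_dim=64, only_cross_attention=True)
out = block(torch.randn(1, 16, 32), context=torch.randn(1, 8, 64))
print(out.shape)  # torch.Size([1, 16, 32])
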
src/diffusers/models/unet_2d_blocks.py

Lines changed: 8 additions & 0 deletions
@@ -34,6 +34,7 @@ def get_down_block(
     downsample_padding=None,
     dual_cross_attention=False,
     use_linear_projection=False,
+    only_cross_attention=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -78,6 +79,7 @@ def get_down_block(
             attn_num_head_channels=attn_num_head_channels,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -143,6 +145,7 @@ def get_up_block(
     cross_attention_dim=None,
     dual_cross_attention=False,
     use_linear_projection=False,
+    only_cross_attention=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -174,6 +177,7 @@ def get_up_block(
             attn_num_head_channels=attn_num_head_channels,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -530,6 +534,7 @@ def __init__(
         add_downsample=True,
         dual_cross_attention=False,
         use_linear_projection=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -564,6 +569,7 @@ def __init__(
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
                     )
                 )
             else:
@@ -1129,6 +1135,7 @@ def __init__(
         add_upsample=True,
         dual_cross_attention=False,
         use_linear_projection=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -1165,6 +1172,7 @@ def __init__(
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
                     )
                 )
             else:

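get_down_block and get_up_block simply forward the flag to the cross-attention blocks, which pass it on to Transformer2DModel. At the UNet level (next file) a single bool is broadcast to one flag per down block, and the list is reversed for the up path. A small sketch of that expansion, with illustrative block types:

# Sketch of the bool -> per-block expansion done in the UNet __init__ (block types are illustrative).
from typing import List, Tuple, Union

def expand_only_cross_attention(
    only_cross_attention: Union[bool, Tuple[bool, ...]],
    down_block_types: Tuple[str, ...],
) -> List[bool]:
    # A plain bool applies to every down block; a tuple is assumed to already be per-block.
    if isinstance(only_cross_attention, bool):
        return [only_cross_attention] * len(down_block_types)
    return list(only_cross_attention)

down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
per_down_block = expand_only_cross_attention((True, True, False), down_block_types)
per_up_block = list(reversed(per_down_block))  # up blocks consume the flags in reverse order

print(per_down_block)  # [True, True, False]
print(per_up_block)    # [False, True, True]
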
src/diffusers/models/unet_2d_condition.py

Lines changed: 19 additions & 0 deletions
@@ -98,6 +98,7 @@ def __init__(
             "DownBlock2D",
         ),
         up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
         block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
@@ -109,6 +110,7 @@ def __init__(
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
+        num_class_embeds: Optional[int] = None,
     ):
         super().__init__()
 
@@ -124,10 +126,17 @@ def __init__(
 
         self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
 
+        # class embedding
+        if num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
         self.down_blocks = nn.ModuleList([])
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])
 
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
@@ -153,6 +162,7 @@ def __init__(
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
             )
             self.down_blocks.append(down_block)
 
@@ -177,6 +187,7 @@ def __init__(
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_attention_head_dim = list(reversed(attention_head_dim))
+        only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -207,6 +218,7 @@ def __init__(
                 attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -258,6 +270,7 @@ def forward(
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
@@ -310,6 +323,12 @@ def forward(
         t_emb = t_emb.to(dtype=self.dtype)
         emb = self.time_embedding(t_emb)
 
+        if self.config.num_class_embeds is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+            emb = emb + class_emb
+
         # 2. pre-process
         sample = self.conv_in(sample)
 

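With num_class_embeds set, forward now expects class_labels and adds their embedding to the timestep embedding. A short end-to-end sketch of the new call; the tiny model sizes and the reading of the labels as a discretised noise level are assumptions made only to keep the example small.

# Illustrative forward call with the new class_labels argument; all sizes are assumptions.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=64,
    attention_head_dim=8,
    num_class_embeds=1000,  # new option from this commit
)

sample = torch.randn(1, 4, 32, 32)                # noisy input
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 64)    # e.g. text embeddings
class_labels = torch.randint(0, 1000, (1,))       # e.g. a discretised noise level (assumption)

out = unet(sample, timestep, encoder_hidden_states, class_labels=class_labels).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
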
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 23 additions & 0 deletions
@@ -166,6 +166,7 @@ def __init__(
             "CrossAttnUpBlockFlat",
             "CrossAttnUpBlockFlat",
         ),
+        only_cross_attention: Union[bool, Tuple[bool]] = False,
         block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
@@ -177,6 +178,7 @@ def __init__(
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
+        num_class_embeds: Optional[int] = None,
     ):
         super().__init__()
 
@@ -192,10 +194,17 @@ def __init__(
 
         self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
 
+        # class embedding
+        if num_class_embeds is not None:
+            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
         self.down_blocks = nn.ModuleList([])
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])
 
+        if isinstance(only_cross_attention, bool):
+            only_cross_attention = [only_cross_attention] * len(down_block_types)
+
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
@@ -221,6 +230,7 @@ def __init__(
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
             )
             self.down_blocks.append(down_block)
 
@@ -245,6 +255,7 @@ def __init__(
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_attention_head_dim = list(reversed(attention_head_dim))
+        only_cross_attention = list(reversed(only_cross_attention))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -275,6 +286,7 @@ def __init__(
                 attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
+                only_cross_attention=only_cross_attention[i],
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -326,6 +338,7 @@ def forward(
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
+        class_labels: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
@@ -378,6 +391,12 @@ def forward(
         t_emb = t_emb.to(dtype=self.dtype)
         emb = self.time_embedding(t_emb)
 
+        if self.config.num_class_embeds is not None:
+            if class_labels is None:
+                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+            emb = emb + class_emb
+
         # 2. pre-process
         sample = self.conv_in(sample)
 
@@ -648,6 +667,7 @@ def __init__(
         add_downsample=True,
         dual_cross_attention=False,
         use_linear_projection=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -682,6 +702,7 @@ def __init__(
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
                     )
                 )
             else:
@@ -861,6 +882,7 @@ def __init__(
         add_upsample=True,
         dual_cross_attention=False,
         use_linear_projection=False,
+        only_cross_attention=False,
     ):
         super().__init__()
         resnets = []
@@ -897,6 +919,7 @@ def __init__(
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
                     )
                 )
             else:
