
Commit f07a16e

update unet2d (#1376)
* boom boom
* remove duplicate arg
* add use_linear_proj arg
* fix copies
* style
* add fast tests
* use_linear_proj -> use_linear_projection
1 parent 16a32c9 commit f07a16e

5 files changed: +110 −21 lines changed


src/diffusers/models/attention.py

Lines changed: 28 additions & 7 deletions
```diff
@@ -99,8 +99,10 @@ def __init__(
         num_vector_embeds: Optional[int] = None,
         activation_fn: str = "geglu",
         num_embeds_ada_norm: Optional[int] = None,
+        use_linear_projection: bool = False,
     ):
         super().__init__()
+        self.use_linear_projection = use_linear_projection
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
@@ -126,7 +128,10 @@ def __init__(
             self.in_channels = in_channels

             self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+            if use_linear_projection:
+                self.proj_in = nn.Linear(in_channels, inner_dim)
+            else:
+                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
             assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -159,7 +164,10 @@ def __init__(

         # 4. Define output layers
         if self.is_input_continuous:
-            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+            if use_linear_projection:
+                self.proj_out = nn.Linear(in_channels, inner_dim)
+            else:
+                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             self.norm_out = nn.LayerNorm(inner_dim)
             self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
@@ -191,10 +199,18 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict=True):
         if self.is_input_continuous:
             batch, channel, height, weight = hidden_states.shape
             residual = hidden_states
+
             hidden_states = self.norm(hidden_states)
-            hidden_states = self.proj_in(hidden_states)
-            inner_dim = hidden_states.shape[1]
-            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+
+            if not self.use_linear_projection:
+                hidden_states = self.proj_in(hidden_states)
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+            else:
+                hidden_states = self.norm(hidden_states)
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+                hidden_states = self.proj_in(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)

@@ -204,8 +220,13 @@ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict=True):

         # 3. Output
         if self.is_input_continuous:
-            hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
-            hidden_states = self.proj_out(hidden_states)
+            if not self.use_linear_projection:
+                hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+                hidden_states = self.proj_out(hidden_states)
+            else:
+                hidden_states = self.proj_out(hidden_states)
+                hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+
             output = hidden_states + residual
         elif self.is_input_vectorized:
             hidden_states = self.norm_out(hidden_states)
```
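For continuous inputs the two projection paths compute the same linear map; what changes is the layout around it: project in NCHW and then flatten, or flatten to (batch, tokens, channels) and then project. A minimal sketch of that equivalence in plain PyTorch, with invented shapes and a weight copy that exists only to line the two paths up:

```python
import torch
from torch import nn

batch, in_channels, height, width = 2, 32, 8, 8
inner_dim = 64
x = torch.randn(batch, in_channels, height, width)

# Conv path (use_linear_projection=False): project in NCHW, then flatten to tokens.
conv = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
tokens_conv = conv(x).permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)

# Linear path (use_linear_projection=True): flatten to tokens, then project.
linear = nn.Linear(in_channels, inner_dim)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(inner_dim, in_channels))  # a 1x1 kernel is just a matrix
    linear.bias.copy_(conv.bias)
tokens_linear = linear(x.permute(0, 2, 3, 1).reshape(batch, height * width, in_channels))

print(torch.allclose(tokens_conv, tokens_linear, atol=1e-5))  # True
```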

src/diffusers/models/unet_2d_blocks.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -33,6 +33,7 @@ def get_down_block(
     cross_attention_dim=None,
     downsample_padding=None,
     dual_cross_attention=False,
+    use_linear_projection=False,
 ):
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock2D":
@@ -76,6 +77,7 @@ def get_down_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )
     elif down_block_type == "SkipDownBlock2D":
         return SkipDownBlock2D(
@@ -140,6 +142,7 @@ def get_up_block(
     resnet_groups=None,
     cross_attention_dim=None,
     dual_cross_attention=False,
+    use_linear_projection=False,
 ):
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock2D":
@@ -170,6 +173,7 @@ def get_up_block(
             cross_attention_dim=cross_attention_dim,
             attn_num_head_channels=attn_num_head_channels,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )
     elif up_block_type == "AttnUpBlock2D":
         return AttnUpBlock2D(
@@ -327,6 +331,7 @@ def __init__(
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
+        use_linear_projection=False,
         **kwargs,
     ):
         super().__init__()
@@ -362,6 +367,7 @@ def __init__(
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -523,6 +529,7 @@ def __init__(
         downsample_padding=1,
         add_downsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []
@@ -556,6 +563,7 @@ def __init__(
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -1120,6 +1128,7 @@ def __init__(
         output_scale_factor=1.0,
         add_upsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []
@@ -1155,6 +1164,7 @@ def __init__(
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
```
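The factory functions above only forward the flag; it bottoms out in the Transformer2DModel constructors inside each cross-attention block. A hedged usage sketch: the keyword names follow the diff, while the concrete values are illustrative and not taken from any shipped config:

```python
from diffusers.models.unet_2d_blocks import get_down_block

block = get_down_block(
    "CrossAttnDownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=320,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    attn_num_head_channels=8,
    cross_attention_dim=768,
    downsample_padding=1,
    use_linear_projection=True,  # the new flag, threaded through to each Transformer2DModel
)
print(type(block.attentions[0].proj_in).__name__)  # Linear instead of Conv2d
```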

src/diffusers/models/unet_2d_condition.py

Lines changed: 14 additions & 7 deletions
```diff
@@ -61,7 +61,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
-        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
@@ -106,8 +106,9 @@ def __init__(
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
-        attention_head_dim: int = 8,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
     ):
         super().__init__()

@@ -127,6 +128,9 @@ def __init__(
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])

+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
@@ -145,9 +149,10 @@ def __init__(
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=attention_head_dim[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.down_blocks.append(down_block)

@@ -160,16 +165,18 @@ def __init__(
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift="default",
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim,
+            attn_num_head_channels=attention_head_dim[-1],
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )

         # count how many layers upsample the images
         self.num_upsamplers = 0

         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -197,8 +204,9 @@ def __init__(
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -256,8 +264,7 @@ def forward(
         Args:
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-            encoder_hidden_states (`torch.FloatTensor`):
-                (batch_size, sequence_length, hidden_size) encoder hidden states
+            encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
```
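With attention_head_dim now accepting a tuple, each down block (and, reversed, each up block) gets its own head width, which is what SD2-style UNet configs need alongside the linear projections. A small instantiation sketch; the values mirror that style but are illustrative here, and everything not shown keeps its default:

```python
from diffusers import UNet2DConditionModel

# One head-dim entry per down block; a plain int still broadcasts as before.
unet = UNet2DConditionModel(
    block_out_channels=(320, 640, 1280, 1280),
    attention_head_dim=(5, 10, 20, 20),
    use_linear_projection=True,
)
print(unet.config.attention_head_dim)  # (5, 10, 20, 20)
```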

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 20 additions & 7 deletions
```diff
@@ -124,7 +124,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
         out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
         center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
-        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
@@ -174,8 +174,9 @@ def __init__(
         norm_num_groups: int = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
-        attention_head_dim: int = 8,
+        attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
     ):
         super().__init__()

@@ -195,6 +196,9 @@ def __init__(
         self.mid_block = None
         self.up_blocks = nn.ModuleList([])

+        if isinstance(attention_head_dim, int):
+            attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
@@ -213,9 +217,10 @@ def __init__(
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=attention_head_dim[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.down_blocks.append(down_block)

@@ -228,16 +233,18 @@ def __init__(
             output_scale_factor=mid_block_scale_factor,
             resnet_time_scale_shift="default",
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attention_head_dim,
+            attn_num_head_channels=attention_head_dim[-1],
             resnet_groups=norm_num_groups,
             dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
         )

         # count how many layers upsample the images
         self.num_upsamplers = 0

         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
+        reversed_attention_head_dim = list(reversed(attention_head_dim))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
             is_final_block = i == len(block_out_channels) - 1
@@ -265,8 +272,9 @@ def __init__(
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
-                attn_num_head_channels=attention_head_dim,
+                attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
+                use_linear_projection=use_linear_projection,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
@@ -324,8 +332,7 @@ def forward(
         Args:
             sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
             timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
-            encoder_hidden_states (`torch.FloatTensor`):
-                (batch_size, sequence_length, hidden_size) encoder hidden states
+            encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

@@ -640,6 +647,7 @@ def __init__(
         downsample_padding=1,
         add_downsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []
@@ -673,6 +681,7 @@ def __init__(
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -851,6 +860,7 @@ def __init__(
         output_scale_factor=1.0,
         add_upsample=True,
         dual_cross_attention=False,
+        use_linear_projection=False,
     ):
         super().__init__()
         resnets = []
@@ -886,6 +896,7 @@ def __init__(
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
@@ -988,6 +999,7 @@ def __init__(
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
+        use_linear_projection=False,
         **kwargs,
     ):
         super().__init__()
@@ -1023,6 +1035,7 @@ def __init__(
                         num_layers=1,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
                     )
                 )
             else:
```
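The flat UNet mirrors unet_2d_condition.py line for line (the "fix copies" step in the commit message). One detail shared by both: the up path walks the blocks in reverse, so the head-dim tuple is reversed before indexing, keeping each resolution paired with the width it used on the way down. A toy trace of that indexing, reusing the illustrative tuple from the earlier sketch:

```python
attention_head_dim = (5, 10, 20, 20)  # per down block
reversed_attention_head_dim = list(reversed(attention_head_dim))

for i, head_channels in enumerate(reversed_attention_head_dim):
    print(f"up block {i}: attn_num_head_channels={head_channels}")
# up block 0: attn_num_head_channels=20
# up block 1: attn_num_head_channels=20
# up block 2: attn_num_head_channels=10
# up block 3: attn_num_head_channels=5
```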
