Merged
35 changes: 20 additions & 15 deletions src/diffusers/models/unets/unet_2d.py
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -103,6 +103,7 @@ def __init__(
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
mid_block_type: Optional[str] = "UNetMidBlock2D",
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
layers_per_block: int = 2,
@@ -194,19 +195,22 @@ def __init__(
self.down_blocks.append(down_block)

# mid
self.mid_block = UNetMidBlock2D(
    in_channels=block_out_channels[-1],
    temb_channels=time_embed_dim,
    dropout=dropout,
    resnet_eps=norm_eps,
    resnet_act_fn=act_fn,
    output_scale_factor=mid_block_scale_factor,
    resnet_time_scale_shift=resnet_time_scale_shift,
    attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
    resnet_groups=norm_num_groups,
    attn_groups=attn_norm_num_groups,
    add_attention=add_attention,
)
if mid_block_type is None:
    self.mid_block = None
else:
Comment on lines +198 to +200
Member:
Hmm, if it's unsupported then should we not be disallowing its usage? I understand that's a breaking change and in that case, a deprecation would be preferred. WDYT?

Contributor Author:
If we don't want to support mid_block_type=None then we can just remove the docstring; the argument itself currently doesn't exist.

if self.mid_block is not None:

We support mid_block_type None in UNet2DConditionModel, so it seemed appropriate to support it in UNet2DModel.

We could use get_mid_block here, but it's simpler like this, as we'd need to guard against the other two CrossAttn mid blocks.

Member:

> We support mid_block_type None in UNet2DConditionModel, so it seemed appropriate to support it in UNet2DModel.

I think since it's unsupported anyway in UNet2DModel, I am not seeing the reason to pass a None, but if it's simpler to do so maintenance-wise, I am okay with that.

Contributor Author:

See the linked issue #10460

>>> import diffusers
>>> diffusers.UNet2DModel(mid_block_type=None)

Member:

Okay, works for me then, but I will note that supporting just None values does feel weird to me.

    self.mid_block = UNetMidBlock2D(
        in_channels=block_out_channels[-1],
        temb_channels=time_embed_dim,
        dropout=dropout,
        resnet_eps=norm_eps,
        resnet_act_fn=act_fn,
        output_scale_factor=mid_block_scale_factor,
        resnet_time_scale_shift=resnet_time_scale_shift,
        attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
        resnet_groups=norm_num_groups,
        attn_groups=attn_norm_num_groups,
        add_attention=add_attention,
    )

# up
reversed_block_out_channels = list(reversed(block_out_channels))
@@ -322,7 +326,8 @@ def forward(
down_block_res_samples += res_samples

# 4. mid
sample = self.mid_block(sample, emb)
if self.mid_block is not None:
    sample = self.mid_block(sample, emb)

# 5. up
skip_sample = None
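As a usage sketch (illustrative only, not part of the diff): with this change, passing mid_block_type=None builds the model without a mid block and the forward pass simply skips it, while the default "UNetMidBlock2D" preserves the previous behaviour for existing configs. The small configuration below (block sizes, sample size) is made up purely for demonstration.

import torch
from diffusers import UNet2DModel

# Hypothetical small config for illustration only; any valid UNet2DModel config should behave the same.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    mid_block_type=None,  # new in this PR: no mid block is created
)
assert model.mid_block is None

sample = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    out = model(sample, timestep=1).sample  # forward now guards on self.mid_block being None
print(out.shape)  # torch.Size([1, 3, 32, 32])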
29 changes: 29 additions & 0 deletions tests/models/unets/test_models_unet_2d.py
@@ -105,6 +105,35 @@ def test_mid_block_attn_groups(self):
    expected_shape = inputs_dict["sample"].shape
    self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

def test_mid_block_none(self):
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
    mid_none_init_dict["mid_block_type"] = None

    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    mid_none_model = self.model_class(**mid_none_init_dict)
    mid_none_model.to(torch_device)
    mid_none_model.eval()

    self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")

    with torch.no_grad():
        output = model(**inputs_dict)

        if isinstance(output, dict):
            output = output.to_tuple()[0]

    with torch.no_grad():
        mid_none_output = mid_none_model(**mid_none_inputs_dict)

        if isinstance(mid_none_output, dict):
            mid_none_output = mid_none_output.to_tuple()[0]

    self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")

def test_gradient_checkpointing_is_applied(self):
    expected_set = {
        "AttnUpBlock2D",