Contributor

@hlky commented Jan 6, 2025

What does this PR do?

mid_block_type is unsupported in UNet2DModel despite appearing in documentation.

if mid_block_type == "UNetMidBlock2DCrossAttn":
    return UNetMidBlock2DCrossAttn(
        transformer_layers_per_block=transformer_layers_per_block,
        in_channels=in_channels,
        temb_channels=temb_channels,
        dropout=dropout,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        output_scale_factor=output_scale_factor,
        resnet_time_scale_shift=resnet_time_scale_shift,
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=num_attention_heads,
        resnet_groups=resnet_groups,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
        attention_type=attention_type,
    )
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
    return UNetMidBlock2DSimpleCrossAttn(
        in_channels=in_channels,
        temb_channels=temb_channels,
        dropout=dropout,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        output_scale_factor=output_scale_factor,
        cross_attention_dim=cross_attention_dim,
        attention_head_dim=attention_head_dim,
        resnet_groups=resnet_groups,
        resnet_time_scale_shift=resnet_time_scale_shift,
        skip_time_act=resnet_skip_time_act,
        only_cross_attention=mid_block_only_cross_attention,
        cross_attention_norm=cross_attention_norm,
    )
elif mid_block_type == "UNetMidBlock2D":
    return UNetMidBlock2D(
        in_channels=in_channels,
        temb_channels=temb_channels,
        dropout=dropout,
        num_layers=0,
        resnet_eps=resnet_eps,
        resnet_act_fn=resnet_act_fn,
        output_scale_factor=output_scale_factor,
        resnet_groups=resnet_groups,
        resnet_time_scale_shift=resnet_time_scale_shift,
        add_attention=False,
    )
elif mid_block_type is None:
    return None

Of these, UNet2DModel only works with UNetMidBlock2D. None can also be supported, and this PR adds that support.
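As a rough illustration of what supporting None involves, here is a minimal sketch of the pattern: leave the mid block unset at construction time and skip it in the forward pass. TinyUNet and its arithmetic "blocks" are hypothetical stand-ins invented for this example; this is not the diffusers implementation.

```python
from typing import Callable, Optional


class TinyUNet:
    """Hypothetical stand-in for a UNet-style model (not the diffusers class)."""

    def __init__(self, mid_block_type: Optional[str] = "UNetMidBlock2D") -> None:
        self.mid_block: Optional[Callable[[float], float]]
        if mid_block_type is None:
            # No mid block is built at all.
            self.mid_block = None
        else:
            # Placeholder for constructing the real mid block.
            self.mid_block = lambda sample: sample + 1.0

    def forward(self, sample: float) -> float:
        sample = sample * 2.0  # stand-in for the down blocks
        if self.mid_block is not None:
            # Only run the mid block when one was built.
            sample = self.mid_block(sample)
        return sample - 3.0  # stand-in for the up blocks
```

With this shape, the only places that need to know about None are the constructor and the single guard in forward.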

Fixes #10460

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +198 to +200
if mid_block_type is None:
    self.mid_block = None
else:
Member

Hmm, if it's unsupported then should we not be disallowing its usage? I understand that's a breaking change and in that case, a deprecation would be preferred. WDYT?
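For reference, one possible reading of the deprecation route mentioned here, sketched with Python's standard warnings module rather than any diffusers helper. The function name resolve_mid_block_type, the _UNSET sentinel, and the warning text are all hypothetical and do not come from the PR.

```python
import warnings

# Sentinel to tell "argument not passed" apart from an explicit None.
_UNSET = object()


def resolve_mid_block_type(mid_block_type=_UNSET) -> str:
    """Keep accepting the argument for now, but warn that it is ignored."""
    if mid_block_type is not _UNSET:
        warnings.warn(
            "mid_block_type is ignored here and may be removed in a future release.",
            FutureWarning,
            stacklevel=2,
        )
    # Assumption: always fall back to the one block type that works.
    return "UNetMidBlock2D"
```

This keeps existing call sites working while signalling that the argument has no effect, which is the trade-off a deprecation makes against an immediate breaking change.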

Contributor Author

If we don't want to support mid_block_type=None, then we can just remove the docstring; the argument itself currently doesn't exist.

if self.mid_block is not None:

We support mid_block_type None in UNet2DConditionModel so it seemed appropriate to support in UNet2DModel.

We could use get_mid_block here, but it's simpler like this, as otherwise we'd need to guard against the other two CrossAttn mid blocks.
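To illustrate the kind of guard that reusing get_mid_block would require, here is a minimal sketch. The helper check_mid_block_type and the SUPPORTED_MID_BLOCK_TYPES tuple are hypothetical names for this example; nothing like them is part of the PR.

```python
from typing import Optional

# Assumption: only these two values make sense for an unconditional UNet.
SUPPORTED_MID_BLOCK_TYPES = ("UNetMidBlock2D", None)


def check_mid_block_type(mid_block_type: Optional[str]) -> Optional[str]:
    """Reject the cross-attention mid blocks before calling a shared factory."""
    if mid_block_type not in SUPPORTED_MID_BLOCK_TYPES:
        raise ValueError(
            f"Unsupported mid_block_type {mid_block_type!r}; "
            f"expected one of {SUPPORTED_MID_BLOCK_TYPES}."
        )
    return mid_block_type
```

The extra validation step is the cost being avoided by handling None directly in the constructor instead.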

Member

> We support mid_block_type None in UNet2DConditionModel so it seemed appropriate to support in UNet2DModel.

I think since it's unsupported anyway in UNet2DModel, I am not seeing the reason to pass a None, but if it's simpler to do so maintenance-wise, I am okay with that.

Contributor Author

See the linked issue #10460

>>> import diffusers
>>> diffusers.UNet2DModel(mid_block_type=None)

Member

Okay, works for me then, but I will note that supporting just None values does feel weird to me.

@yiyixuxu yiyixuxu merged commit b13cdbb into huggingface:main Jan 8, 2025
12 checks passed

Development

Successfully merging this pull request may close these issues.

UNet2DModel is missing the documented mid_block_type argument

5 participants