Skip to content

Commit 284cd8e

Browse files
committed
pass copy check
1 parent 7518788 commit 284cd8e

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2095,6 +2095,8 @@ class UNetMidBlockFlat(nn.Module):
20952095
attention_head_dim (`int`, *optional*, defaults to 1):
20962096
Dimension of a single attention head. The number of attention heads is determined based on this value and
20972097
the number of input channels.
2098+
attention_legacy_order (`bool`, *optional*, defaults to `False`):
2099+
if attention_legacy_order, split heads before split qkv, see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328
20982100
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
20992101
21002102
Returns:
@@ -2110,21 +2112,22 @@ def __init__(
21102112
dropout: float = 0.0,
21112113
num_layers: int = 1,
21122114
resnet_eps: float = 1e-6,
2113-
resnet_time_scale_shift: str = "default", # default, spatial
2115+
resnet_time_scale_shift: str = "default", # default, spatial, scale_shift
21142116
resnet_act_fn: str = "swish",
21152117
resnet_groups: int = 32,
21162118
attn_groups: Optional[int] = None,
21172119
resnet_pre_norm: bool = True,
21182120
add_attention: bool = True,
21192121
attention_head_dim: int = 1,
2122+
attention_legacy_order: bool = False,
21202123
output_scale_factor: float = 1.0,
21212124
):
21222125
super().__init__()
21232126
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
21242127
self.add_attention = add_attention
21252128

21262129
if attn_groups is None:
2127-
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
2130+
attn_groups = None if resnet_time_scale_shift == "spatial" else resnet_groups
21282131

21292132
# there is always at least one resnet
21302133
if resnet_time_scale_shift == "spatial":
@@ -2163,7 +2166,6 @@ def __init__(
21632166
f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
21642167
)
21652168
attention_head_dim = in_channels
2166-
21672169
for _ in range(num_layers):
21682170
if self.add_attention:
21692171
attentions.append(
@@ -2179,6 +2181,7 @@ def __init__(
21792181
bias=True,
21802182
upcast_softmax=True,
21812183
_from_deprecated_attn_block=True,
2184+
attention_legacy_order=attention_legacy_order,
21822185
)
21832186
)
21842187
else:

0 commit comments

Comments
 (0)