@@ -2095,6 +2095,8 @@ class UNetMidBlockFlat(nn.Module):
         attention_head_dim (`int`, *optional*, defaults to 1):
             Dimension of a single attention head. The number of attention heads is determined based on this value and
             the number of input channels.
+        attention_legacy_order (`bool`, *optional*, defaults to `False`):
+            If `True`, split heads before splitting qkv; see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328
         output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

     Returns:
@@ -2110,21 +2112,22 @@ def __init__(
         dropout: float = 0.0,
         num_layers: int = 1,
         resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",  # default, spatial
+        resnet_time_scale_shift: str = "default",  # default, spatial, scale_shift
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         attn_groups: Optional[int] = None,
         resnet_pre_norm: bool = True,
         add_attention: bool = True,
         attention_head_dim: int = 1,
+        attention_legacy_order: bool = False,
         output_scale_factor: float = 1.0,
     ):
         super().__init__()
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
         self.add_attention = add_attention

         if attn_groups is None:
-            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+            attn_groups = None if resnet_time_scale_shift == "spatial" else resnet_groups

         # there is always at least one resnet
         if resnet_time_scale_shift == "spatial":
@@ -2163,7 +2166,6 @@ def __init__(
21632166 f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: { in_channels } ."
21642167 )
21652168 attention_head_dim = in_channels
2166-
21672169 for _ in range (num_layers ):
21682170 if self .add_attention :
21692171 attentions .append (
@@ -2179,6 +2181,7 @@ def __init__(
                         bias=True,
                         upcast_softmax=True,
                         _from_deprecated_attn_block=True,
+                        attention_legacy_order=attention_legacy_order,
                     )
                 )
             else:
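For context on what the new `attention_legacy_order` flag controls: this diff only threads the flag through to the `Attention` block. The sketch below is illustrative only (the helper names `split_qkv_legacy` and `split_qkv_default` are not part of this change); it mirrors the difference between `QKVAttentionLegacy` and `QKVAttention` in openai/guided-diffusion, assuming a fused qkv projection of shape `[batch, 3 * num_heads * head_dim, seq_len]`.

import torch

def split_qkv_legacy(qkv: torch.Tensor, num_heads: int):
    # Legacy order: reshape into heads first, then slice q, k, v per head.
    bs, width, length = qkv.shape
    head_dim = width // (3 * num_heads)
    q, k, v = qkv.reshape(bs * num_heads, 3 * head_dim, length).split(head_dim, dim=1)
    return q, k, v

def split_qkv_default(qkv: torch.Tensor, num_heads: int):
    # Default order: split q, k, v first, then split each into heads.
    bs, width, length = qkv.shape
    head_dim = width // (3 * num_heads)
    q, k, v = (t.reshape(bs * num_heads, head_dim, length) for t in qkv.chunk(3, dim=1))
    return q, k, v

Both variants return tensors of shape `[batch * num_heads, head_dim, seq_len]`, but they map the channels of the fused qkv projection to heads differently, so weights trained with one ordering give wrong attention outputs if reshaped with the other, which is why the ordering needs to be selectable on the block.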