@@ -2095,6 +2095,8 @@ class UNetMidBlockFlat(nn.Module):
         attention_head_dim (`int`, *optional*, defaults to 1):
             Dimension of a single attention head. The number of attention heads is determined based on this value and
             the number of input channels.
+        attention_legacy_order (`bool`, *optional*, defaults to `False`):
+            If `True`, split heads before splitting qkv, following the legacy ordering in https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328
         output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

     Returns:
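# Not part of the diff: a minimal sketch of what `attention_legacy_order` changes.
# Names and shapes below are illustrative assumptions, not the library's actual attention code.
import torch

bs, n_heads, ch, length = 2, 4, 8, 16
qkv = torch.randn(bs, 3 * n_heads * ch, length)  # fused q/k/v projection output

# Default ordering: split q/k/v first, then fold heads into the batch dimension.
# Assumes the channel axis is laid out as [q (all heads) | k (all heads) | v (all heads)].
q, k, v = (t.reshape(bs * n_heads, ch, length) for t in qkv.chunk(3, dim=1))

# Legacy ordering (guided-diffusion): split heads first, then split q/k/v per head.
# Assumes the channel axis is laid out as [head0: q, k, v | head1: q, k, v | ...].
q_l, k_l, v_l = qkv.reshape(bs * n_heads, 3 * ch, length).split(ch, dim=1)

# The two orderings index the fused projection weights differently, so a checkpoint trained
# with the legacy layout only stays correct when loaded with `attention_legacy_order=True`.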
@@ -2110,21 +2112,22 @@ def __init__(
         dropout: float = 0.0,
         num_layers: int = 1,
         resnet_eps: float = 1e-6,
-        resnet_time_scale_shift: str = "default",  # default, spatial
+        resnet_time_scale_shift: str = "default",  # default, spatial, scale_shift
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         attn_groups: Optional[int] = None,
         resnet_pre_norm: bool = True,
         add_attention: bool = True,
         attention_head_dim: int = 1,
+        attention_legacy_order: bool = False,
         output_scale_factor: float = 1.0,
     ):
         super().__init__()
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
         self.add_attention = add_attention

         if attn_groups is None:
-            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+            attn_groups = None if resnet_time_scale_shift == "spatial" else resnet_groups

         # there is always at least one resnet
         if resnet_time_scale_shift == "spatial":
@@ -2163,7 +2166,6 @@ def __init__(
                 f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
             )
             attention_head_dim = in_channels
-
         for _ in range(num_layers):
             if self.add_attention:
                 attentions.append(
@@ -2179,6 +2181,7 @@ def __init__(
                         bias=True,
                         upcast_softmax=True,
                         _from_deprecated_attn_block=True,
+                        attention_legacy_order=attention_legacy_order,
                     )
                 )
             else:
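# Not part of the diff: a hypothetical construction of the mid block with the new flag,
# assuming `UNetMidBlockFlat` from the diffed module is in scope and the remaining
# arguments keep the signature shown above; the channel values are made up.
mid_block = UNetMidBlockFlat(
    in_channels=512,
    temb_channels=512,
    resnet_time_scale_shift="scale_shift",
    attention_head_dim=64,
    attention_legacy_order=True,  # use the guided-diffusion heads-before-qkv split
)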