@@ -64,7 +64,7 @@ def __init__(
             num_heads: int = 8,
             qkv_bias: bool = False,
             qk_norm: bool = False,
-            scale_attn_norm: bool = False,
+            scale_norm: bool = False,
             proj_bias: bool = True,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
@@ -80,7 +80,7 @@ def __init__(
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.norm = norm_layer(dim) if scale_attn_norm else nn.Identity()
+        self.norm = norm_layer(dim) if scale_norm else nn.Identity()
         self.attn_drop = nn.Dropout(attn_drop)
         self.proj = nn.Linear(dim, dim, bias=proj_bias)
         self.proj_drop = nn.Dropout(proj_drop)
@@ -151,7 +151,7 @@ def __init__(
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
-            scale_attn_norm=scale_attn_norm,
+            scale_norm=scale_attn_norm,
             proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -205,7 +205,7 @@ def __init__(
             num_heads=num_heads,
             qkv_bias=qkv_bias,
             qk_norm=qk_norm,
-            scale_attn_norm=scale_attn_norm,
+            scale_norm=scale_attn_norm,
             proj_bias=proj_bias,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
@@ -253,6 +253,8 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             proj_drop: float = 0.,
             attn_drop: float = 0.,
@@ -264,6 +266,7 @@ def __init__(
     ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported'
         self.num_heads = num_heads
         self.head_dim = dim // num_heads
         self.scale = self.head_dim ** -0.5
@@ -348,6 +351,8 @@ def __init__(
             mlp_ratio: float = 4.,
             qkv_bias: bool = False,
             qk_norm: bool = False,
+            scale_attn_norm: bool = False,
+            scale_mlp_norm: bool = False,
             proj_bias: bool = True,
             init_values: Optional[float] = None,
             proj_drop: float = 0.,
@@ -369,6 +374,7 @@ def __init__(
                     num_heads=num_heads,
                     qkv_bias=qkv_bias,
                     qk_norm=qk_norm,
+                    scale_norm=scale_attn_norm,
                     proj_bias=proj_bias,
                     attn_drop=attn_drop,
                     proj_drop=proj_drop,
@@ -383,6 +389,7 @@ def __init__(
                     dim,
                     hidden_features=int(dim * mlp_ratio),
                     act_layer=act_layer,
+                    norm_layer=norm_layer if scale_mlp_norm else None,
                     bias=proj_bias,
                     drop=proj_drop,
                 )),
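
Not part of the commit itself: a minimal usage sketch of the renamed flag, assuming these hunks come from timm's vision_transformer.py and that Attention is importable from timm.models.vision_transformer. Only the constructor keyword arguments are taken from the hunks above; the import path and the forward call are assumptions to be checked against the actual source.

# Minimal sketch, not the commit's code. Assumes the diff targets the Attention
# module in timm.models.vision_transformer.
import torch
from timm.models.vision_transformer import Attention

attn = Attention(
    dim=768,
    num_heads=12,
    qkv_bias=True,
    qk_norm=True,
    scale_norm=True,  # renamed from scale_attn_norm; when True, self.norm = norm_layer(dim) instead of nn.Identity()
)

x = torch.randn(1, 197, 768)  # (batch, tokens, dim)
out = attn(x)                 # output keeps the input shape: (1, 197, 768)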