
Commit 936d20e

reuse fused_attn from timm, add activation between proj and cls_head
1 parent 89d348d commit 936d20e

1 file changed (+25, -10 lines)

timm/models/pe.py

Lines changed: 25 additions & 10 deletions
@@ -11,13 +11,16 @@
 from torch.nn.parameter import Parameter
 from torch.amp import autocast
 from torch.utils.checkpoint import checkpoint
+from torch.jit import Final
+

 ### Import timm layers
 from timm.layers import (
     DropPath,
     AttentionPoolLatent,
     LayerType,
     LayerScale,
+    use_fused_attn,
 )

 # from timm.layers import RotaryEmbeddingCat, RotaryEmbedding # not compatible
@@ -70,6 +73,7 @@ def forward(self, t: Tensor):
         return freqs


+
 @register_notrace_module
 class Rope2D(Module):
     def __init__(self, dim, grid_size, use_cls_token=False):
@@ -181,6 +185,8 @@ class SelfAttention(nn.Module):
     r"""
     Implements sequence packed attention and RoPe
     """
+    fused_attn: Final[bool]
+
     def __init__(
             self,
             embed_dim: int,
@@ -201,12 +207,14 @@ def __init__(

         self.rope = rope
         self.scale = self.head_dim ** (-0.5)
+        self.fused_attn = use_fused_attn()

     def init_tensors(self):
         xavier_uniform_(self.in_proj_weight)
         constant_(self.in_proj_bias, 0.0)
         constant_(self.out_proj.bias, 0.0)

+
     def forward(self,
             x: torch.Tensor,
             attn_mask: Optional[torch.Tensor] = None,
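Note (not part of the diff): use_fused_attn() is timm's global switch for PyTorch's fused scaled-dot-product-attention kernel, and the change above caches its result on the module at construction time. A minimal sketch of that pattern, using only the names the diff itself imports:

import torch.nn as nn
from torch.jit import Final
from timm.layers import use_fused_attn

# Sketch of the pattern the diff adopts: read the global fused-attention
# decision once at construction time and store it as a torchscript-Final bool.
class FusedFlagDemo(nn.Module):
    fused_attn: Final[bool]

    def __init__(self):
        super().__init__()
        self.fused_attn = use_fused_attn()  # True when the fused SDPA path should be used

print(FusedFlagDemo().fused_attn)

Whether your timm version also exports a global override helper (e.g. set_fused_attn) depends on the installed release; treat that name as an assumption rather than something shown in this commit.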
@@ -226,12 +234,21 @@ def forward(self,
         if self.rope is not None:
             q, k = self.rope(q, k)

-        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
+        if self.fused_attn:
+            attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = attn @ v
+
         attn = attn.permute(0, 2, 1, 3).contiguous().view(batch, seq, -1)

         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)


+
+
 class ResidualAttentionBlock(nn.Module):
     def __init__(
             self,
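For reference (not part of the commit), a small self-contained check that the manual fallback added above matches F.scaled_dot_product_attention when no mask or dropout is used; tensor shapes and the scale value are made up, and the scale= keyword assumes a reasonably recent PyTorch (2.1+):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Assumed layout: (batch, heads, seq, head_dim), as SDPA expects.
q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
scale = 32 ** -0.5

# Fused path (taken when self.fused_attn is True).
fused = F.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=scale)

# Manual fallback, mirroring the else-branch of the hunk above.
attn = (q * scale) @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
manual = attn @ v

print(torch.allclose(fused, manual, atol=1e-5))  # expected: True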
@@ -246,10 +263,7 @@ def __init__(
     ):
         super().__init__()

-        #if rope:
         self.attn = SelfAttention(d_model, n_head, rope=rope)
-        #else:
-        #    self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)

         self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
         self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
@@ -281,10 +295,7 @@ def _call_attn(
             if not attn_mask.dtype == torch.bool:
                 attn_mask = attn_mask.to(q_x.dtype)

-        #if isinstance(self.attn, SelfAttention):
         return self.attn(q_x, attn_mask=attn_mask)
-        #else:
-        #    return self.attn(q_x, q_x, q_x, attn_mask=attn_mask, need_weights=False)[0]

     def forward(
             self,
@@ -392,6 +403,7 @@ def __init__(
         self.in_chans = in_chans
         self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.emb_dim = width

         # PE contains an (optional) projection layer
         # Flow: x -> Transfomer(x) -> pool -> proj -> head (for timm).
@@ -410,6 +422,7 @@ def __init__(
         self.num_features = width

         self.num_classes = num_classes
+        self.output_dim = output_dim

         self.use_abs_posemb = use_abs_posemb
         self.use_cls_token = use_cls_token
@@ -466,6 +479,7 @@ def __init__(
         else:
             self.attn_pool = None

+        self.act_layer_cfg = act_layer
         self.init_tensors()

     def init_tensors(self):
@@ -523,8 +537,10 @@ def forward_pool_and_proj(self, x: torch.Tensor):

     def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
         # PE has an additional proj layer: Transfomer(x) -> pool -> proj -> head (for timm).
-        # Ideally pool To discuss with Ross where to split
+        # To discuss with Ross where to split
         x = self.forward_pool_and_proj(x)
+        if self.head_act_layer is not None:
+            x = self.head_act_layer(x)
         return x if pre_logits else self.head(x)

     def forward_features(self, x: torch.Tensor, norm: bool = False):
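To illustrate the commit-message change ("add activation between proj and cls_head"), here is a hypothetical standalone module with the same pool -> proj -> activation -> head flow. The class name, widths, and GELU choice below are illustrative assumptions, not the real configuration of timm/models/pe.py; only head_act_layer and head mirror names from the hunk above.

import torch
import torch.nn as nn

# Toy stand-in for the PE head flow: project pooled features, optionally apply
# an activation, then classify.
class ToyPEHead(nn.Module):
    def __init__(self, width=512, output_dim=256, num_classes=1000, act_layer=nn.GELU):
        super().__init__()
        self.proj = nn.Linear(width, output_dim)
        self.head_act_layer = act_layer() if act_layer is not None else None
        self.head = nn.Linear(output_dim, num_classes)

    def forward(self, x: torch.Tensor, pre_logits: bool = False):
        x = self.proj(x)                     # stands in for forward_pool_and_proj
        if self.head_act_layer is not None:  # the activation this commit inserts
            x = self.head_act_layer(x)
        return x if pre_logits else self.head(x)

print(ToyPEHead()(torch.randn(2, 512)).shape)  # torch.Size([2, 1000])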
@@ -806,5 +822,4 @@ def vit_pe_spatial_gigantic_patch14_448(pretrained=False, **kwargs):
         ls_init_value=0.1,
         use_proj=False,
     )
-    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
-
+    return _create_pe('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
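A hedged usage sketch for the registry entry touched in the last hunk, assuming a timm build that includes timm/models/pe.py; note this is a very large model, so instantiating it needs substantial memory, and pretrained=False avoids any weight download:

import timm
import torch

# Build the registered model and run a single forward pass at its native 448px input.
model = timm.create_model('vit_pe_spatial_gigantic_patch14_448', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 448, 448))
print(out.shape)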
