Commit c5c437a

remove einops dependencies and reimplement with torch functions
1 parent b82f9e8 commit c5c437a

File tree

1 file changed (+12, -18 lines)

timm/models/pe.py

Lines changed: 12 additions & 18 deletions
@@ -4,7 +4,6 @@
 
 import torch
 import torch.nn as nn
-from einops import rearrange, repeat
 from torch import nn, Tensor, broadcast_tensors, einsum
 from torch.nn import functional as F
 from torch.nn import Module, ModuleList
@@ -49,17 +48,14 @@
 def exists(val):
     return val is not None
 
-
 def default(val, d):
     return val if exists(val) else d
 
-
 def rotate_half(x):
-    x = rearrange(x, "... (d r) -> ... d r", r=2)
-    x1, x2 = x.unbind(dim=-1)
+    x = x.view(*x.shape[:-1], -1, 2)
+    x1, x2 = x[..., 0], x[..., 1]
     x = torch.stack((-x2, x1), dim=-1)
-    return rearrange(x, "... d r -> ... (d r)")
-
+    return x.view(*x.shape[:-2], -1)
 
 @autocast("cuda", enabled=False)
 def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
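
The one behavioral detail worth checking in the rotate_half rewrite is the pair layout: einops' "(d r)" with r=2 splits the last dimension into interleaved pairs, and x.view(*x.shape[:-1], -1, 2) produces the same grouping (view, unlike rearrange, raises rather than copies if the memory layout is incompatible). A minimal standalone check, not part of the commit and with illustrative helper names:

import torch
from einops import rearrange

def rotate_half_einops(x):  # the removed implementation, kept only for comparison
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")

def rotate_half_torch(x):
    x = x.view(*x.shape[:-1], -1, 2)     # split the last dim into interleaved pairs
    x1, x2 = x[..., 0], x[..., 1]
    x = torch.stack((-x2, x1), dim=-1)   # each pair (x1, x2) becomes (-x2, x1)
    return x.view(*x.shape[:-2], -1)     # flatten the pairs back into the last dim

x = torch.randn(2, 8, 64)
assert torch.equal(rotate_half_einops(x), rotate_half_torch(x))
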
@@ -86,7 +82,6 @@ def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
 
     return out.type(dtype)
 
-
 class RotaryEmbedding(Module):
     def __init__(
         self,
@@ -187,7 +182,7 @@ def rotate_queries_or_keys(self, t, seq_dim=None, offset=0):
         )
 
         if seq_dim == -3:
-            freqs = rearrange(freqs, "n d -> n 1 d")
+            freqs = freqs.unsqueeze(1)
 
         return apply_rotary_emb(freqs, t, seq_dim=seq_dim)
 
@@ -217,8 +212,8 @@ def rotate_queries_and_keys(self, q, k, seq_dim=None):
         scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
 
         if seq_dim == -3:
-            freqs = rearrange(freqs, "n d -> n 1 d")
-            scale = rearrange(scale, "n d -> n 1 d")
+            freqs = freqs.unsqueeze(1)
+            scale = scale.unsqueeze(1)
 
         rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
         rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)
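
This hunk and the previous one replace the same pattern: rearrange(t, "n d -> n 1 d") only inserts a singleton axis, so unsqueeze(1) is a drop-in substitute. A quick illustrative check with arbitrary sizes:

import torch
from einops import rearrange

freqs = torch.randn(16, 32)                    # (n, d)
assert torch.equal(freqs.unsqueeze(1),         # (n, 1, d)
                   rearrange(freqs, "n d -> n 1 d"))
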
@@ -230,7 +225,6 @@ def rotate_queries_and_keys(self, q, k, seq_dim=None):
 
     def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
         assert self.use_xpos
-
         should_cache = self.cache_if_possible and exists(seq_len)
 
         if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales.shape[0]:
@@ -239,7 +233,7 @@ def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0):
         scale = 1.0
         if self.use_xpos:
             power = (t - len(t) // 2) / self.scale_base
-            scale = self.scale ** rearrange(power, "n -> n 1")
+            scale = self.scale ** power.unsqueeze(-1)
             scale = torch.cat((scale, scale), dim=-1)
 
         if should_cache:
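
power is 1-D here, so unsqueeze(-1) (equivalently unsqueeze(1)) reproduces rearrange(power, "n -> n 1"), and the subsequent exponentiation and concatenation behave identically. Illustrative check:

import torch
from einops import rearrange

power = torch.randn(16)                        # (n,)
assert torch.equal(power.unsqueeze(-1),        # (n, 1)
                   rearrange(power, "n -> n 1"))
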
@@ -280,7 +274,7 @@ def forward(self, t: Tensor, seq_len=None, offset=0):
         freqs = self.freqs
 
         freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
-        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+        freqs = freqs.repeat_interleave(2, dim=-1)
 
         if should_cache:
             self.tmp_store("cached_freqs", freqs.detach())
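
repeat(freqs, "... n -> ... (n r)", r=2) duplicates each frequency in place along the last dimension ([f0, f0, f1, f1, ...]), which is exactly what Tensor.repeat_interleave(2, dim=-1) does; Tensor.repeat would instead tile the whole row ([f0, f1, ..., f0, f1, ...]) and give a different layout. Illustrative check:

import torch
from einops import repeat

freqs = torch.randn(4, 10)
a = repeat(freqs, "... n -> ... (n r)", r=2)   # [f0, f0, f1, f1, ...]
b = freqs.repeat_interleave(2, dim=-1)
assert torch.equal(a, b)
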
@@ -414,15 +408,15 @@ def forward(self, x, attn_mask=None):
         q, k, v = proj[0], proj[1], proj[2]
 
         # Use "q_" so that we don't accidentally quit in pdb :)
-        q = rearrange(q, "b s (h d) -> b h s d", h=self.num_heads)
-        k = rearrange(k, "b s (h d) -> b h s d", h=self.num_heads)
-        v = rearrange(v, "b s (h d) -> b h s d", h=self.num_heads)
+        q = q.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        k = k.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        v = v.view(batch, seq, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
 
         if self.rope:
             q, k = self.rope(q, k)
 
         attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale)
-        attn = rearrange(attn, "b h s d -> b s (h d)")
+        attn = attn.permute(0, 2, 1, 3).contiguous().view(batch, seq, -1)
 
         return F.linear(attn, self.out_proj.weight, self.out_proj.bias)
 
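
The head split/merge is the most visible spot where contiguity matters: after permute the tensor is no longer contiguous, so the merge calls .contiguous() before .view (rearrange handles this internally via reshape). A small round-trip check; batch, seq, num_heads and head_dim are arbitrary illustrative values, not taken from the model:

import torch
from einops import rearrange

batch, seq, num_heads, head_dim = 2, 5, 4, 16
q = torch.randn(batch, seq, num_heads * head_dim)

split_einops = rearrange(q, "b s (h d) -> b h s d", h=num_heads)
split_torch = q.view(batch, seq, num_heads, head_dim).permute(0, 2, 1, 3)
assert torch.equal(split_einops, split_torch)

merged = split_torch.permute(0, 2, 1, 3).contiguous().view(batch, seq, -1)
assert torch.equal(merged, rearrange(split_torch, "b h s d -> b s (h d)"))
assert torch.equal(merged, q)                  # split then merge is the identity
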