Commit 5b743f5

Static attention: do not specialize on input sequence length
Differential Revision: D80181012
Pull Request resolved: #13373
1 parent 0ad3df9 commit 5b743f5


1 file changed: +6 -6 lines changed


examples/models/llama/static_attention.py

Lines changed: 6 additions & 6 deletions
@@ -759,7 +759,7 @@ def forward(
 
         bsz, seq_len, dim = x.shape
         if self.use_conv2d:
-            x = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)
+            x = x.reshape(bsz, -1, 1, dim).transpose(1, 3)
 
         new_qs = [wq(x) for wq in self.wqs]
         new_ks = [wk(x) for wk in self.wks]
@@ -768,9 +768,7 @@ def forward(
         if self.use_conv2d:
 
             def from_conv2ds(ts):
-                return [
-                    t.reshape(bsz, self.head_dim, seq_len).transpose(1, 2) for t in ts
-                ]
+                return [t.reshape(bsz, self.head_dim, -1).transpose(1, 2) for t in ts]
 
             new_qs = from_conv2ds(new_qs)
             new_ks = from_conv2ds(new_ks)
@@ -800,9 +798,11 @@ def from_conv2ds(ts):
 
         if self.use_conv2d:
            y = (
-                self.wo(y.reshape(bsz, seq_len, 1, -1).transpose(1, 3))
+                self.wo(
+                    y.reshape(bsz, -1, 1, self.n_heads * self.head_dim).transpose(1, 3)
+                )
                .transpose(1, 3)
-                .reshape(bsz, seq_len, -1)
+                .reshape(bsz, -1, self.dim)
            )
        else:
            y = self.wo(y)
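
The pattern is the same in all three hunks: reshapes that previously passed the traced seq_len now pass -1 (or sizes derived from module attributes), so the exported program no longer references the concrete sequence length of the example input. Below is a minimal, self-contained sketch of the idea; the Projection module, tensor sizes, and the torch.export call are illustrative assumptions, not part of the ExecuTorch export flow.

# Minimal sketch (assumed names, not from the repo): reshaping with -1
# instead of the traced seq_len keeps the sequence dimension dynamic
# when the module is exported with torch.export.
import torch
from torch.export import Dim, export


class Projection(torch.nn.Module):
    """Toy stand-in for the conv2d projection path in static attention."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.dim = dim
        self.proj = torch.nn.Conv2d(dim, dim, kernel_size=1)

    def forward(self, x):
        bsz, _, dim = x.shape
        # -1 asks reshape to infer the sequence dimension from the input,
        # so the graph carries no reference to a concrete seq_len value.
        x = x.reshape(bsz, -1, 1, dim).transpose(1, 3)  # (bsz, dim, 1, seq)
        y = self.proj(x)                                # 1x1 conv over channels
        return y.transpose(1, 3).reshape(bsz, -1, self.dim)


example = torch.randn(1, 8, 16)
seq = Dim("seq_len", min=2, max=128)  # mark dim 1 of x as dynamic
ep = export(Projection(), (example,), dynamic_shapes={"x": {1: seq}})
print(ep)  # the printed graph keeps a symbolic size for the sequence dim

Per the commit title, the earlier reshapes caused the graph to specialize on the input sequence length; letting reshape infer that dimension (and spelling the channel sizes via self.n_heads * self.head_dim and self.dim) keeps the conv2d path shape-agnostic in that dimension.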

0 commit comments
