diff --git a/examples/models/llama/static_attention.py b/examples/models/llama/static_attention.py
index e3859b98210..5ffd25f2c7f 100644
--- a/examples/models/llama/static_attention.py
+++ b/examples/models/llama/static_attention.py
@@ -759,7 +759,7 @@ def forward(
         bsz, seq_len, dim = x.shape
         if self.use_conv2d:
-            x = x.reshape(bsz, seq_len, 1, dim).transpose(1, 3)
+            x = x.reshape(bsz, -1, 1, dim).transpose(1, 3)
 
         new_qs = [wq(x) for wq in self.wqs]
         new_ks = [wk(x) for wk in self.wks]
@@ -768,9 +768,7 @@ def forward(
         if self.use_conv2d:
 
             def from_conv2ds(ts):
-                return [
-                    t.reshape(bsz, self.head_dim, seq_len).transpose(1, 2) for t in ts
-                ]
+                return [t.reshape(bsz, self.head_dim, -1).transpose(1, 2) for t in ts]
 
             new_qs = from_conv2ds(new_qs)
             new_ks = from_conv2ds(new_ks)
@@ -800,9 +798,11 @@ def from_conv2ds(ts):
 
         if self.use_conv2d:
             y = (
-                self.wo(y.reshape(bsz, seq_len, 1, -1).transpose(1, 3))
+                self.wo(
+                    y.reshape(bsz, -1, 1, self.n_heads * self.head_dim).transpose(1, 3)
+                )
                 .transpose(1, 3)
-                .reshape(bsz, seq_len, -1)
+                .reshape(bsz, -1, self.dim)
             )
         else:
             y = self.wo(y)
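
A minimal sketch (not part of the patch) of why the reshapes switch from the captured `seq_len` to `-1`: inferring the sequence dimension at runtime keeps the conv2d-layout reshape valid for any sequence length, rather than baking a specific length into the traced graph. The helper name and toy dimensions below are assumptions for illustration only.

```python
import torch

# Toy dimensions for illustration; not taken from the patch.
bsz, dim = 1, 64


def to_conv2d_input(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the patched reshape: -1 infers the sequence length from the
    # input, so the same code path handles different seq_len values.
    b = x.shape[0]
    return x.reshape(b, -1, 1, x.shape[-1]).transpose(1, 3)


for seq_len in (8, 32):
    x = torch.randn(bsz, seq_len, dim)
    y = to_conv2d_input(x)
    # Conv2d-style layout: (bsz, dim, 1, seq_len)
    assert y.shape == (bsz, dim, 1, seq_len)
```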