diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index eee4aacf44d..60183801b42 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -352,26 +352,18 @@ def forward(
 
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.
-        # k: [bsz, seq_len, n_kv, 1, h_d]
-        # v: [bsz, seq_len, n_kv, 1, h_d]
-        k = k.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
-        v = v.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
-
-        # Expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            k = k.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
-            v = v.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
-
-        # [bsz, s, n_h, h_d]
-        k = k.reshape(bsz, -1, self.num_heads, self.head_dim)
-        v = v.reshape(bsz, -1, self.num_heads, self.head_dim)
-
         # [bsz, n_h, s, h_d]
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
 
+        # Expand the key and value tensors to have the same shape
+        # as the query tensor by copying values across the relevant dim
+        if self.num_heads != self.num_kv_heads:
+            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
+            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+
         output = self._attention_fn(
             q,
             k,
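
The snippet below is a minimal standalone sketch, not part of the diff, checking that the new path (transpose first, then unsqueeze + expand + flatten) produces the same [bsz, n_h, s, h_d] tensor as the old view + expand + reshape path; all dimension values are arbitrary examples chosen for illustration, not taken from the module.

# Illustrative equivalence check (assumed example dimensions, not from attention.py).
import torch

bsz, seq_len, num_heads, num_kv_heads, head_dim = 2, 5, 8, 2, 16
q_per_kv = num_heads // num_kv_heads

# k as it looks before either expansion: [bsz, seq_len, n_kv, h_d]
k = torch.randn(bsz, seq_len, num_kv_heads, head_dim)

# Old path: expand KV heads before the transpose.
k_old = k.view(bsz, -1, num_kv_heads, 1, head_dim)
k_old = k_old.expand(bsz, -1, num_kv_heads, q_per_kv, head_dim)
k_old = k_old.reshape(bsz, -1, num_heads, head_dim)  # [bsz, s, n_h, h_d]
k_old = k_old.transpose(1, 2)                        # [bsz, n_h, s, h_d]

# New path: transpose first, then unsqueeze + expand + flatten.
expand_shape = (-1, -1, q_per_kv, -1, -1)
k_new = k.transpose(1, 2)                            # [bsz, n_kv, s, h_d]
k_new = k_new.unsqueeze(2).expand(expand_shape)      # [bsz, n_kv, q_per_kv, s, h_d]
k_new = k_new.flatten(1, 2)                          # [bsz, n_h, s, h_d]

assert torch.equal(k_old, k_new)

Both orderings interleave the repeated KV heads the same way (head index = kv_head * q_per_kv + j), so the attention output is unchanged; only the shape bookkeeping moves to after the transpose.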