From d25f7f60685764786f99dc416fbc6617e5248784 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Fri, 15 Nov 2024 16:31:33 -0800
Subject: [PATCH] [llama-mm] Reduce copies in SDPA in MHA

Summary:
As titled. Align the implementation for SDPA with the torchtune version
https://github.com/pytorch/torchtune/blob/main/torchtune/modules/attention.py#L267

Test Plan:
Rely on unit tests

Reviewers:

Subscribers:

Tasks:

Tags:
---
 extension/llm/modules/attention.py | 22 +++++++---------------
 1 file changed, 7 insertions(+), 15 deletions(-)

diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py
index eee4aacf44d..60183801b42 100644
--- a/extension/llm/modules/attention.py
+++ b/extension/llm/modules/attention.py
@@ -352,26 +352,18 @@ def forward(
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.
-        # k: [bsz, seq_len, n_kv, 1, h_d]
-        # v: [bsz, seq_len, n_kv, 1, h_d]
-        k = k.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
-        v = v.view(bsz, -1, self.num_kv_heads, 1, self.head_dim)
-
-        # Expand the key and value tensors to have the same shape
-        # as the query tensor by copying values across the relevant dim
-        if self.num_heads != self.num_kv_heads:
-            k = k.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
-            v = v.expand(bsz, -1, self.num_kv_heads, self.q_per_kv, self.head_dim)
-
-        # [bsz, s, n_h, h_d]
-        k = k.reshape(bsz, -1, self.num_heads, self.head_dim)
-        v = v.reshape(bsz, -1, self.num_heads, self.head_dim)
-
         # [bsz, n_h, s, h_d]
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

+        # Expand the key and value tensors to have the same shape
+        # as the query tensor by copying values across the relevant dim
+        if self.num_heads != self.num_kv_heads:
+            expand_shape = (-1, -1, self.q_per_kv, -1, -1)
+            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)
+
         output = self._attention_fn(
             q,
             k,
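
Note: for reference, below is a minimal standalone sketch of the unsqueeze + expand + flatten
pattern the new code path uses to broadcast each kv head across q_per_kv query heads. The
tensor sizes and variable names (bsz, seq_len, etc.) are illustrative assumptions, not taken
from the patched module.

import torch

# Illustrative sizes only (assumptions, not from the patch).
bsz, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 2, 16, 64
q_per_kv = num_heads // num_kv_heads

# k (and v) have already been transposed to [bsz, n_kv, s, h_d], matching the
# point in forward() where the patch now performs the expansion.
k = torch.randn(bsz, num_kv_heads, seq_len, head_dim)

# unsqueeze adds a repeat dim, expand broadcasts across it without copying,
# and flatten(1, 2) merges (n_kv, q_per_kv) into n_h.
expand_shape = (-1, -1, q_per_kv, -1, -1)
k_expanded = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)

assert k_expanded.shape == (bsz, num_heads, seq_len, head_dim)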