Skip to content

Commit 7a1c402

Browse files
authored
[Kernel] [CPU] refactor cpu_attn.py:_run_sdpa_forward for better memory access (vllm-project#24701)
Signed-off-by: ignaciosica <[email protected]>
1 parent 60a0951 commit 7a1c402

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

vllm/v1/attention/backends/cpu_attn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -641,10 +641,6 @@ def _run_sdpa_forward(
641641
attn_metadata: TorchSDPAMetadata,
642642
attn_type: str = AttentionType.DECODER,
643643
) -> None:
644-
if self.num_kv_heads != self.num_heads:
645-
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
646-
value = value.repeat_interleave(self.num_queries_per_kv, dim=1)
647-
648644
attn_masks = attn_metadata.get_attn_bias(attn_type)
649645
if attn_masks is None:
650646
if self.alibi_slopes is not None:
@@ -665,6 +661,10 @@ def _run_sdpa_forward(
665661
key = key.movedim(0, key.dim() - 2)
666662
value = value.movedim(0, value.dim() - 2)
667663

664+
if self.num_kv_heads != self.num_heads:
665+
key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
666+
value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)
667+
668668
causal_attn = (attn_type == AttentionType.DECODER)
669669

670670
seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type)

0 commit comments

Comments
 (0)