File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
vllm/v1/attention/backends Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -641,10 +641,6 @@ def _run_sdpa_forward(
641
641
attn_metadata : TorchSDPAMetadata ,
642
642
attn_type : str = AttentionType .DECODER ,
643
643
) -> None :
644
- if self .num_kv_heads != self .num_heads :
645
- key = key .repeat_interleave (self .num_queries_per_kv , dim = 1 )
646
- value = value .repeat_interleave (self .num_queries_per_kv , dim = 1 )
647
-
648
644
attn_masks = attn_metadata .get_attn_bias (attn_type )
649
645
if attn_masks is None :
650
646
if self .alibi_slopes is not None :
@@ -665,6 +661,10 @@ def _run_sdpa_forward(
665
661
key = key .movedim (0 , key .dim () - 2 )
666
662
value = value .movedim (0 , value .dim () - 2 )
667
663
664
+ if self .num_kv_heads != self .num_heads :
665
+ key = key .repeat_interleave (self .num_queries_per_kv , dim = - 3 )
666
+ value = value .repeat_interleave (self .num_queries_per_kv , dim = - 3 )
667
+
668
668
causal_attn = (attn_type == AttentionType .DECODER )
669
669
670
670
seq_lens_q , seq_lens_kv = attn_metadata .get_seq_lens (attn_type )
You can’t perform that action at this time.
0 commit comments