Skip to content
This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit 54e0441

Browse files
authored
Revert "remove redundant slice; match decode PA partition size with csrc (#188)" (#194)
This reverts commit c68c242.
1 parent 7d3690c commit 54e0441

File tree

1 file changed: +8 additions, −13 deletions

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,8 @@
1818
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
1919

2020
logger = init_logger(__name__)
21-
#keep _PARTITION_SIZE in sync with csrc/rocm/attention.cu
22-
_PARTITION_SIZE = 512
21+
22+
_PARTITION_SIZE = 256
2323
ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
2424

2525

@@ -517,10 +517,7 @@ def forward(
517517

518518
# common code for prefill
519519
assert output[:num_prefill_tokens].shape == out.shape
520-
if output.shape[0] > num_prefill_tokens:
521-
output[:num_prefill_tokens] = out
522-
else:
523-
output = out
520+
output[:num_prefill_tokens] = out
524521
else:
525522
# prefix-enabled attention
526523
output[:num_prefill_tokens] = PagedAttention.forward_prefix(
@@ -567,13 +564,11 @@ def forward(
567564
)
568565
max_logits = torch.empty_like(exp_sums)
569566
ops.paged_attention_rocm(
570-
output, exp_sums, max_logits, tmp_output, decode_query,
571-
key_cache, value_cache, self.num_kv_heads, self.scale,
572-
decode_meta.block_tables, decode_meta.seq_lens_tensor,
573-
block_size, max_seq_len, self.alibi_slopes,
574-
self.kv_cache_dtype, k_scale, v_scale)
575-
if num_prefill_tokens > 0:
576-
output = output[num_prefill_tokens:]
567+
output[num_prefill_tokens:], exp_sums, max_logits,
568+
tmp_output, decode_query, key_cache, value_cache,
569+
self.num_kv_heads, self.scale, decode_meta.block_tables,
570+
decode_meta.seq_lens_tensor, block_size, max_seq_len,
571+
self.alibi_slopes, self.kv_cache_dtype, k_scale, v_scale)
577572
else:
578573
output[num_prefill_tokens:] = PagedAttention.forward_decode(
579574
decode_query,

0 commit comments

Comments (0)