@@ -18,8 +18,8 @@
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 
 logger = init_logger(__name__)
-#keep _PARTITION_SIZE in sync with csrc/rocm/attention.cu
-_PARTITION_SIZE = 512
+
+_PARTITION_SIZE = 256
 ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
 
 
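The dropped comment noted that `_PARTITION_SIZE` must stay in sync with the partition size compiled into `csrc/rocm/attention.cu`: the ROCm paged attention kernel splits each sequence's KV cache into fixed-size partitions and reduces the partial results afterwards. Below is a minimal sketch of how such a constant typically drives the per-partition workspace shapes; the helper name and the exact shapes/dtypes are assumptions for illustration, not code from this file:

```python
import torch

# Assumption: must match the partition size compiled into csrc/rocm/attention.cu.
_PARTITION_SIZE = 256

def decode_workspace(num_seqs: int, num_heads: int, head_size: int,
                     max_seq_len: int, device: str = "cpu"):
    # Hypothetical sketch: each sequence is split into
    # ceil(max_seq_len / _PARTITION_SIZE) partitions; the kernel later
    # reduces the partial results (tmp_output, exp_sums, max_logits)
    # across partitions. Use device="cuda" on a GPU build.
    max_num_partitions = (max_seq_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    tmp_output = torch.empty(
        (num_seqs, num_heads, max_num_partitions, head_size),
        dtype=torch.float32, device=device)
    exp_sums = torch.empty(
        (num_seqs, num_heads, max_num_partitions),
        dtype=torch.float32, device=device)
    max_logits = torch.empty_like(exp_sums)
    return tmp_output, exp_sums, max_logits
```

Halving the partition size doubles `max_num_partitions` for a given sequence length, so the constant here and the kernel's compiled value have to agree or the workspace shapes no longer match what the kernel expects.
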
@@ -517,10 +517,7 @@ def forward(
 
                 # common code for prefill
                 assert output[:num_prefill_tokens].shape == out.shape
-                if output.shape[0] > num_prefill_tokens:
-                    output[:num_prefill_tokens] = out
-                else:
-                    output = out
+                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 output[:num_prefill_tokens] = PagedAttention.forward_prefix(
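Replacing the conditional with a single in-place slice assignment removes the case where `output` was rebound to a different tensor for prefill-only batches, while other code paths kept writing into the original preallocated buffer. A tiny standalone illustration of the difference; all tensor names and shapes here are hypothetical stand-ins:

```python
import torch

num_prefill_tokens, num_decode_tokens, hidden = 4, 2, 8
output = torch.zeros(num_prefill_tokens + num_decode_tokens, hidden)
out = torch.ones(num_prefill_tokens, hidden)  # stand-in for the prefill result

# In-place slice assignment copies `out` into the preallocated buffer, so
# decode results written later to output[num_prefill_tokens:] land in the
# same tensor.
output[:num_prefill_tokens] = out
assert output.shape[0] == num_prefill_tokens + num_decode_tokens

# Rebinding (`output = out`) would instead make `output` a 4-row tensor
# and silently drop the rows reserved for decode tokens.
```
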
@@ -567,13 +564,11 @@ def forward(
                 )
                 max_logits = torch.empty_like(exp_sums)
                 ops.paged_attention_rocm(
-                    output, exp_sums, max_logits, tmp_output, decode_query,
-                    key_cache, value_cache, self.num_kv_heads, self.scale,
-                    decode_meta.block_tables, decode_meta.seq_lens_tensor,
-                    block_size, max_seq_len, self.alibi_slopes,
-                    self.kv_cache_dtype, k_scale, v_scale)
-                if num_prefill_tokens > 0:
-                    output = output[num_prefill_tokens:]
+                    output[num_prefill_tokens:], exp_sums, max_logits,
+                    tmp_output, decode_query, key_cache, value_cache,
+                    self.num_kv_heads, self.scale, decode_meta.block_tables,
+                    decode_meta.seq_lens_tensor, block_size, max_seq_len,
+                    self.alibi_slopes, self.kv_cache_dtype, k_scale, v_scale)
             else:
                 output[num_prefill_tokens:] = PagedAttention.forward_decode(
                     decode_query,
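Passing `output[num_prefill_tokens:]` straight to `ops.paged_attention_rocm` works because slicing a contiguous tensor along its first dimension returns a contiguous view over the same storage: the kernel writes its decode results directly into the tail of the combined prefill+decode buffer, making the old slice-after-the-call fixup unnecessary. A short demonstration of the view semantics (shapes hypothetical):

```python
import torch

num_prefill_tokens = 4
output = torch.zeros(6, 8)  # combined prefill + decode output buffer

decode_view = output[num_prefill_tokens:]  # a view: shares storage
assert decode_view.is_contiguous()

# Emulate a kernel writing into the view; the buffer's tail rows are
# mutated in place while the prefill rows stay untouched.
decode_view.fill_(1.0)
assert output[num_prefill_tokens:].eq(1.0).all()
assert output[:num_prefill_tokens].eq(0.0).all()
```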