
Commit 78aa33e

Author: Aleksandr Malyshev (committed)

updated logic for attn selection with default split attn and increased size for CAR

1 parent b193a40, commit 78aa33e

File tree: 4 files changed, +36 -17 lines

vllm/attention/backends/rocm_flash_attn.py
Lines changed: 1 addition & 1 deletion

@@ -784,7 +784,7 @@ def forward(
                    attn_masks[0][None]
                    if attn_masks is not None else None,
                    full_scales,
-                    layer._out_scale,
+                    output_scale,
                )
            else:
                output[:num_prefill_tokens] = self.triton_attn_func(

vllm/distributed/device_communicators/custom_all_reduce.py
Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ class CustomAllreduce:
    def __init__(self,
                 group: ProcessGroup,
                 device: Union[int, str, torch.device],
-                 max_size=2 * 8192 * 1024) -> None:
+                 max_size=8 * 8192 * 1024) -> None:
        """
        Args:
            group: the process group to work on. If None, it will use the
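Note (the byte figures below are simple arithmetic on the two constants in this diff, not stated elsewhere in the commit): the default cap for the custom all-reduce (CAR) buffer grows from 16 MiB to 64 MiB, presumably so that larger tensors remain eligible for the custom all-reduce path.

# Quick check of the old and new defaults, in MiB.
old_max_size = 2 * 8192 * 1024   # 16,777,216 bytes
new_max_size = 8 * 8192 * 1024   # 67,108,864 bytes
print(old_max_size // (1024 * 1024), "MiB ->", new_max_size // (1024 * 1024), "MiB")  # 16 MiB -> 64 MiB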

vllm/v1/attention/backends/rocm_aiter_fa.py
Lines changed: 1 addition & 0 deletions

@@ -325,6 +325,7 @@ def build(self,
            dtype=torch.uint8,
            device=self.device,
        )
+        if max_query_len > 1:
            # We pre-compute cumulative seq len needed for prefill attention
            # here to avoid recomputing it for every layer
            cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
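Note (an inference from the guard, not stated in the diff): max_query_len > 1 indicates that the batch contains at least one prefill request, so the cumulative-sequence-length tensor is now only built when prefill attention will actually use it; decode-only batches skip the work. A minimal sketch of that condition, with hypothetical query_lens/seq_lens tensors:

import torch

# Hypothetical batch: per-request query lengths and total sequence lengths.
# Any query length > 1 means at least one request is still in prefill.
query_lens = torch.tensor([1, 1, 7, 1], dtype=torch.int32)
seq_lens = torch.tensor([9, 12, 7, 30], dtype=torch.int32)
max_query_len = int(query_lens.max())

if max_query_len > 1:
    # Mirrors the guarded block above: cumulative seq lens are only needed
    # for prefill attention.
    cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1, dtype=torch.int32)
    cu_seq_lens[1:] = torch.cumsum(seq_lens, dim=0)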

vllm/v1/attention/backends/triton_attn.py
Lines changed: 33 additions & 15 deletions
@@ -30,7 +30,9 @@
if current_platform.is_rocm():
    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
-        from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
+        from aiter.ops.triton.fused_kv_cache import (
+            fused_qk_rope_reshape_and_cache)
+

@dataclass
class TritonAttentionMetadata:
@@ -250,23 +252,24 @@ def __init__(
                "TritonAttentionImpl")

        self.fp8_dtype = current_platform.fp8_dtype()
-        self.force_prefill_decode_attn = \
-            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

        # If not using prefill decode attention, we use the Triton
        # unified attention implementation.
        if use_aiter_unified_attention():
            logger.info_once(
                "Using aiter unified attention for TritonAttentionImpl")
-            from aiter.ops.triton.unified_attention import (
-                unified_attention)
+            from aiter.ops.triton.unified_attention import unified_attention
            self.unified_attention = unified_attention
-        elif not self.force_prefill_decode_attn:
+        elif not envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
            logger.info_once(
                "Using vllm unified attention for TritonAttentionImpl")
            from vllm.attention.ops.triton_unified_attention import (
                unified_attention)
            self.unified_attention = unified_attention
+        else:
+            logger.info_once(
+                "Using vllm split prefill decode attention for TritonAttentionImpl"
+            )

        self.sinks = sinks
        if sinks is not None:
@@ -324,30 +327,45 @@ def forward(
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

-        use_prefill_decode_attn = self.force_prefill_decode_attn
+        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION \
+            and not use_aiter_unified_attention()
        num_actual_tokens = attn_metadata.num_actual_tokens

        if use_prefill_decode_attn:
            key_cache, value_cache = PagedAttention.split_kv_cache(
                kv_cache, self.num_kv_heads, self.head_size)
        else:
            key_cache, value_cache = kv_cache.unbind(0)
-
+
        if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
-            assert self.kv_sharing_target_layer_name is None, "self.kv_sharing_target_layer_name error"
-            cos, sin = cos_sin_cache.chunk(2, dim = -1)
+            assert self.kv_sharing_target_layer_name is None, "self.kv_sharing_target_layer_name error"
+            cos, sin = cos_sin_cache.chunk(2, dim=-1)
            is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
            if is_fp8_kv_cache:
                key_cache_og_dtype = key_cache.dtype
                value_cache_og_dtype = value_cache.dtype
                key_cache = key_cache.view(self.fp8_dtype)
                value_cache = value_cache.view(self.fp8_dtype)
            query, key, key_cache, value_cache, output = fused_qk_rope_reshape_and_cache(
-                query, key, value, key_cache, value_cache, attn_metadata.slot_mapping,
-                positions, cos, sin,
-                layer._k_scale, layer._v_scale,
-                is_neox,
-                flash_layout=(not use_prefill_decode_attn), apply_scale=is_fp8_kv_cache, offs=None, q_out=query, k_out=key, output_zeros=True, zeros_out=output)
+                query,
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                positions,
+                cos,
+                sin,
+                layer._k_scale,
+                layer._v_scale,
+                is_neox,
+                flash_layout=(not use_prefill_decode_attn),
+                apply_scale=is_fp8_kv_cache,
+                offs=None,
+                q_out=query,
+                k_out=key,
+                output_zeros=True,
+                zeros_out=output)
            if is_fp8_kv_cache:
                key_cache = key_cache.view(key_cache_og_dtype)
                value_cache = value_cache.view(value_cache_og_dtype)
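Note: taken together, the __init__ and forward() changes make the backend choice depend only on use_aiter_unified_attention() and envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION, with AITER unified attention taking precedence and the split prefill/decode path used whenever the flag is set (the "default split attn" of the commit title). A minimal standalone sketch of that precedence (plain booleans stand in for the helper and the env flag; this is an illustration, not vLLM code):

def select_attention(aiter_unified: bool, prefill_decode_flag: bool) -> str:
    # aiter_unified stands in for use_aiter_unified_attention();
    # prefill_decode_flag for envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION.
    if aiter_unified:
        return "aiter unified attention"
    if prefill_decode_flag:
        return "vllm split prefill decode attention"
    return "vllm unified attention"

assert select_attention(True, True) == "aiter unified attention"
assert select_attention(False, True) == "vllm split prefill decode attention"
assert select_attention(False, False) == "vllm unified attention"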
