@@ -30,7 +30,9 @@
 if current_platform.is_rocm():
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
     if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
-        from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
+        from aiter.ops.triton.fused_kv_cache import (
+            fused_qk_rope_reshape_and_cache)
+

 @dataclass
 class TritonAttentionMetadata:
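Note on the hunk above: the optional aiter kernel is imported only when both environment flags resolve truthy, so builds without the aiter package never attempt the import. A minimal sketch of that guarded-import pattern, reading plain `os.environ` in place of vllm's `envs` module (the truthy-string parsing here is an assumption, not vllm's actual logic):

```python
import os

def _flag(name: str) -> bool:
    # Assumed "truthy string" convention; vllm's envs module does the
    # real parsing and default handling.
    return os.environ.get(name, "0") in ("1", "true", "True")

USE_FUSED_ROPE_KV_CACHE = (
    _flag("VLLM_ROCM_USE_AITER")
    and _flag("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE"))

if USE_FUSED_ROPE_KV_CACHE:
    # Guarded import: hosts without the optional aiter wheel never
    # execute this line, so module import cannot fail on them.
    from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
```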
@@ -250,23 +252,24 @@ def __init__(
                 "TritonAttentionImpl")

         self.fp8_dtype = current_platform.fp8_dtype()
-        self.force_prefill_decode_attn = \
-            envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION

         # If not using prefill decode attention, we use the Triton
         # unified attention implementation.
         if use_aiter_unified_attention():
             logger.info_once(
                 "Using aiter unified attention for TritonAttentionImpl")
-            from aiter.ops.triton.unified_attention import (
-                unified_attention)
+            from aiter.ops.triton.unified_attention import unified_attention
             self.unified_attention = unified_attention
-        elif not self.force_prefill_decode_attn:
+        elif not envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION:
             logger.info_once(
                 "Using vllm unified attention for TritonAttentionImpl")
             from vllm.attention.ops.triton_unified_attention import (
                 unified_attention)
             self.unified_attention = unified_attention
+        else:
+            logger.info_once(
+                "Using vllm split prefill decode attention for TritonAttentionImpl"
+            )

         self.sinks = sinks
         if sinks is not None:
@@ -324,30 +327,45 @@ def forward(
         # Whenever making a change in this method, please benchmark the
         # performance to make sure it does not introduce any overhead.

-        use_prefill_decode_attn = self.force_prefill_decode_attn
+        use_prefill_decode_attn = envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION \
+            and not use_aiter_unified_attention()
         num_actual_tokens = attn_metadata.num_actual_tokens

         if use_prefill_decode_attn:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
         else:
             key_cache, value_cache = kv_cache.unbind(0)
-        
+
         if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
-            assert self.kv_sharing_target_layer_name is None, "self.kv_sharing_target_layer_name error"
-            cos, sin = cos_sin_cache.chunk(2, dim = -1)
+            assert self.kv_sharing_target_layer_name is None, "self.kv_sharing_target_layer_name error"
+            cos, sin = cos_sin_cache.chunk(2, dim=-1)
             is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
             if is_fp8_kv_cache:
                 key_cache_og_dtype = key_cache.dtype
                 value_cache_og_dtype = value_cache.dtype
                 key_cache = key_cache.view(self.fp8_dtype)
                 value_cache = value_cache.view(self.fp8_dtype)
             query, key, key_cache, value_cache, output = fused_qk_rope_reshape_and_cache(
-                query, key, value, key_cache, value_cache, attn_metadata.slot_mapping,
-                positions, cos, sin,
-                layer._k_scale, layer._v_scale,
-                is_neox,
-                flash_layout=(not use_prefill_decode_attn), apply_scale=is_fp8_kv_cache, offs=None, q_out=query, k_out=key, output_zeros=True, zeros_out=output)
+                query,
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                positions,
+                cos,
+                sin,
+                layer._k_scale,
+                layer._v_scale,
+                is_neox,
+                flash_layout=(not use_prefill_decode_attn),
+                apply_scale=is_fp8_kv_cache,
+                offs=None,
+                q_out=query,
+                k_out=key,
+                output_zeros=True,
+                zeros_out=output)
             if is_fp8_kv_cache:
                 key_cache = key_cache.view(key_cache_og_dtype)
                 value_cache = value_cache.view(value_cache_og_dtype)
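One detail worth calling out in the forward() hunk: when the KV cache is fp8, the cache tensors are reinterpreted (not copied) as the fp8 dtype so the fused kernel can write quantized values in place, then viewed back so downstream code sees the original dtype. A sketch of that round-trip, assuming a same-width storage dtype (e.g. uint8) and using a placeholder where the real fused kernel would run:

```python
import torch

def _fused_kernel_inplace(cache_fp8: torch.Tensor) -> None:
    """Placeholder for fused_qk_rope_reshape_and_cache, which writes
    quantized K/V into the cache in place."""

def roundtrip_fp8_view(key_cache: torch.Tensor) -> torch.Tensor:
    og_dtype = key_cache.dtype                       # e.g. torch.uint8
    cache_fp8 = key_cache.view(torch.float8_e4m3fn)  # zero-copy reinterpret
    _fused_kernel_inplace(cache_fp8)                 # kernel sees fp8 elements
    return cache_fp8.view(og_dtype)                  # restore original dtype

cache = torch.empty(4, 16, dtype=torch.uint8)
assert roundtrip_fp8_view(cache).dtype == torch.uint8
```

Because view() only reinterprets storage, the pre/post view pair costs nothing at runtime; the og_dtype bookkeeping exists solely so the rest of forward() is unaffected by the fp8 path.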