diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 3697275a7b64..41871922a109 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -527,8 +527,9 @@ def unified_attention_with_output(
     kv_cache = self.kv_cache[forward_context.virtual_engine]
 
     from vllm.v1.attention.backends.triton_attn import TritonAttentionImpl
+    from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionImpl
     from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLAImpl
-    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and (isinstance(self.impl, TritonAttentionImpl) or isinstance(self.impl, AiterMLAImpl)):
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and (isinstance(self.impl, TritonAttentionImpl) or isinstance(self.impl, AiterFlashAttentionImpl) or isinstance(self.impl, AiterMLAImpl)):
         # fusing RoPE with flushing kv_cache operation
         assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None and positions is not None, f"rotary_emb not found in {self.impl=} and positions cannot be None"
         self.impl.forward(self,
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index a652d4f47b62..56f680e212a9 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -39,7 +39,7 @@
 
 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
     VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM
-    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and not envs.VLLM_ROCM_USE_AITER_MHA
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD
     if VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
         from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
@@ -51,7 +51,8 @@
     VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM = False
 
 VLLM_ROCM_USE_AITER_MHA = envs.VLLM_ROCM_USE_AITER_MHA
-logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=} {VLLM_ROCM_USE_AITER_MHA=}")
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_MHA=}")
 logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD=}")
 logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM=}")
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 3a185296a93c..1dc51fcc00ea 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -63,13 +63,14 @@
 
 if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
     from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
-    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and not envs.VLLM_ROCM_USE_AITER_MHA
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
 else:
     VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
     VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
 
 VLLM_ROCM_USE_AITER_MHA = envs.VLLM_ROCM_USE_AITER_MHA
-logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=} {VLLM_ROCM_USE_AITER_MHA=}")
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
+logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_MHA=}")
 
 
 class LlamaMLP(nn.Module):
diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py
index 5b56f0493860..999f94546bb3 100644
--- a/vllm/v1/attention/backends/rocm_aiter_fa.py
+++ b/vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -15,10 +15,24 @@
                                               AttentionMetadataBuilder,
                                               CommonAttentionMetadata)
 from vllm.v1.kv_cache_interface import AttentionSpec
+from vllm import envs
+
+logger = init_logger(__name__)
 
 _PARTITION_SIZE_ROCM = 256
 
-if current_platform.is_rocm():
+if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
+        envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+    )
+    VLLM_USE_AITER_TRITON_ROPE = envs.VLLM_USE_AITER_TRITON_ROPE
+    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
+        from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
+else:
+    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
+    VLLM_USE_AITER_TRITON_ROPE = False
+
+if current_platform.is_rocm():
     import aiter
 
     from vllm.triton_utils import tl, triton
@@ -209,8 +223,6 @@ def flash_attn_varlen_func_fake(
     flash_attn_varlen_func_fake,
     dispatch_key=current_platform.dispatch_key)
 
-logger = init_logger(__name__)
-
 
 @dataclass
 class AiterFlashAttentionMetadata:
@@ -420,6 +432,8 @@ def __init__(
         if self.sinks is not None:
             raise NotImplementedError("Sinks are not supported for ROCM AITER")
 
+        self.fp8_dtype = current_platform.fp8_dtype()
+
     def forward(
         self,
         layer: torch.nn.Module,
@@ -430,6 +444,7 @@
         attn_metadata: AiterFlashAttentionMetadata,
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
+        positions: torch.Tensor = None,
         output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with AiterFlashAttention.
@@ -469,24 +484,70 @@
         num_actual_tokens = attn_metadata.num_actual_tokens
         key_cache, value_cache = kv_cache.unbind(0)
 
-        if self.kv_sharing_target_layer_name is None:
-            # Reshape the input keys and values and store them in the cache.
-            # Skip this if sharing KV cache with an earlier attention layer.
-            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
-            # not padded. However, we don't need to do key[:num_actual_tokens]
-            # and value[:num_actual_tokens] because the reshape_and_cache_flash
-            # op uses the slot_mapping's shape to determine the number of
-            # actual tokens.
-            torch.ops._C_cache_ops.reshape_and_cache_flash(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                layer._k_scale,
-                layer._v_scale,
+        if positions is not None and query.shape[0] <= 256:
+            assert (
+                self.kv_sharing_target_layer_name is None
+            ), "self.kv_sharing_target_layer_name cannot be None"
+            assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}"
+            cos_sin_cache = self.rotary_emb.cos_sin_cache
+            is_neox = self.rotary_emb.is_neox_style
+            cos, sin = cos_sin_cache.chunk(2, dim=-1)
+            is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
+            if is_fp8_kv_cache:
+                key_cache_og_dtype = key_cache.dtype
+                value_cache_og_dtype = value_cache.dtype
+                key_cache = key_cache.view(self.fp8_dtype)
+                value_cache = value_cache.view(self.fp8_dtype)
+            query, key, key_cache, value_cache, output = (
+                fused_qk_rope_reshape_and_cache(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    positions,
+                    cos,
+                    sin,
+                    layer._k_scale,
+                    layer._v_scale,
+                    is_neox,
+                    flash_layout=True,
+                    apply_scale=is_fp8_kv_cache,
+                    offs=None,
+                    q_out=query,
+                    k_out=key,
+                    output_zeros=True,
+                    zeros_out=output,
+                )
             )
+            if is_fp8_kv_cache:
+                key_cache = key_cache.view(key_cache_og_dtype)
+                value_cache = value_cache.view(value_cache_og_dtype)
+        else:
+            if positions is not None:
+                if VLLM_USE_AITER_TRITON_ROPE:
+                    query, key = self.rotary_emb.forward_cuda(positions, query, key)
+                else:
+                    query, key = self.rotary_emb(positions, query, key)
+            if self.kv_sharing_target_layer_name is None:
+                # Reshape the input keys and values and store them in the cache.
+                # Skip this if sharing KV cache with an earlier attention layer.
+                # NOTE(woosuk): Here, key and value are padded while slot_mapping is
+                # not padded. However, we don't need to do key[:num_actual_tokens]
+                # and value[:num_actual_tokens] because the reshape_and_cache_flash
+                # op uses the slot_mapping's shape to determine the number of
+                # actual tokens.
+                torch.ops._C_cache_ops.reshape_and_cache_flash(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping,
+                    self.kv_cache_dtype,
+                    layer._k_scale,
+                    layer._v_scale,
+                )
 
         if self.kv_cache_dtype.startswith("fp8"):
             if current_platform.is_fp8_fnuz():
diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py
index 533eb1cc5f79..0cfdb41d1f17 100644
--- a/vllm/v1/attention/backends/triton_attn.py
+++ b/vllm/v1/attention/backends/triton_attn.py
@@ -295,8 +295,6 @@ def forward(
         output: Optional[torch.Tensor] = None,
         output_scale: Optional[torch.Tensor] = None,
         positions: torch.Tensor = None,
-        cos_sin_cache: torch.Tensor = None,
-        is_neox: bool = False,
         output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
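Reviewer sketch (not part of the patch): the fused_qk_rope_reshape_and_cache call added in rocm_aiter_fa.py folds several steps of the fallback else-branch into one Triton launch. The unfused Python equivalent below is an illustration only, written against the same objects already in scope in AiterFlashAttentionImpl.forward(); the zero-fill in step 3 is an assumption inferred from output_zeros=True / zeros_out=output and the flag name VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE, not something the diff states explicitly.

import torch

def unfused_equivalent(self, layer, query, key, value, key_cache, value_cache,
                       attn_metadata, positions, output):
    # 1) Apply rotary position embeddings to query/key at the given positions,
    #    as the fallback branch of the patched forward() does.
    query, key = self.rotary_emb(positions, query, key)
    # 2) Scatter the new key/value into the paged KV cache (flash layout),
    #    quantizing with layer._k_scale / layer._v_scale for fp8 caches.
    torch.ops._C_cache_ops.reshape_and_cache_flash(
        key, value, key_cache, value_cache, attn_metadata.slot_mapping,
        self.kv_cache_dtype, layer._k_scale, layer._v_scale)
    # 3) Zero-fill the attention output buffer, assumed to be the effect of
    #    output_zeros=True / zeros_out=output in the fused kernel.
    output.zero_()
    return query, key, output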