3 changes: 2 additions & 1 deletion vllm/attention/layer.py
@@ -30,11 +30,11 @@
USE_XFORMERS_OPS = None

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE

Check failure on line 33 (GitHub Actions / pre-commit): Ruff (E501) vllm/attention/layer.py:33:81: Line too long (116 > 80)
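A possible fix for this E501 (a sketch only; it mirrors how the same assignment is wrapped later in this diff in rocm_aiter_fa.py) is to break the line with parentheses:

    VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
        envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
    )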
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False

logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")

Check failure on line 37 (GitHub Actions / pre-commit): Ruff (G004) vllm/attention/layer.py:37:13: Logging statement uses f-string
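A possible fix for the G004 warning (a sketch only, not part of this diff) is lazy %-style formatting, so the message is built only when INFO logging is enabled:

    logger.info(
        "[Aiter] VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=%s",
        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE,
    )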

def check_xformers_availability():
global USE_XFORMERS_OPS
@@ -527,8 +527,9 @@
kv_cache = self.kv_cache[forward_context.virtual_engine]

from vllm.v1.attention.backends.triton_attn import TritonAttentionImpl
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionImpl
from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLAImpl
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and (isinstance(self.impl, TritonAttentionImpl) or isinstance(self.impl, AiterMLAImpl)):
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and (isinstance(self.impl, TritonAttentionImpl) or isinstance(self.impl, AiterFlashAttentionImpl) or isinstance(self.impl, AiterMLAImpl)):
# fusing RoPE with flushing kv_cache operation
assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None and positions is not None, f"rotary_emb not found in {self.impl=} and positions cannot be None"
self.impl.forward(self,
@@ -536,13 +537,13 @@
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,

Check failure on line 542 (GitHub Actions / pre-commit): Ruff (SIM101) vllm/attention/layer.py:540:13: Multiple `isinstance` calls for expression, merge into a single call
positions=positions)
else:
assert positions is None, f"positions must be None {positions=}"
self.impl.forward(self,

Check failure on line 546 (GitHub Actions / pre-commit): Ruff (E501) vllm/attention/layer.py:546:81: Line too long (142 > 80)
query,
key,
value,
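For the SIM101 warning above, a possible fix (a sketch only, with the same semantics as the added condition) merges the three checks into a single isinstance call with a tuple of types, which also brings the flagged line under the E501 length limit:

    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and isinstance(
            self.impl,
            (TritonAttentionImpl, AiterFlashAttentionImpl, AiterMLAImpl)):
        ...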
5 changes: 3 additions & 2 deletions vllm/model_executor/models/gpt_oss.py
@@ -39,7 +39,7 @@

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and not envs.VLLM_ROCM_USE_AITER_MHA
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD:
from aiter.ops.triton.fused_add_rmsnorm_pad import fused_add_rmsnorm_pad
@@ -51,7 +51,8 @@
VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM = False

VLLM_ROCM_USE_AITER_MHA = envs.VLLM_ROCM_USE_AITER_MHA
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=} {VLLM_ROCM_USE_AITER_MHA=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_MHA=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ADD_RMSNORM_PAD=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_BF16_GEMM=}")

5 changes: 3 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -63,13 +63,14 @@

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
from vllm.model_executor.layers.activation import VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and not envs.VLLM_ROCM_USE_AITER_MHA
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
else:
VLLM_ROCM_USE_AITER_TRITON_SILU_MUL_FP8_QUANT = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False

VLLM_ROCM_USE_AITER_MHA = envs.VLLM_ROCM_USE_AITER_MHA
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=} {VLLM_ROCM_USE_AITER_MHA=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}")
logger.info(f"[Aiter] {VLLM_ROCM_USE_AITER_MHA=}")

class LlamaMLP(nn.Module):

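For reference, a minimal sketch of enabling this fused-RoPE path at launch time. It assumes a ROCm build of vLLM with AITER installed; the model name is purely illustrative, and the variables are set before vLLM is imported because the flags above are read at module import time:

    import os

    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    os.environ["VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE"] = "1"

    from vllm import LLM  # noqa: E402

    llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model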
101 changes: 81 additions & 20 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
@@ -15,10 +15,24 @@
AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs

logger = init_logger(__name__)

_PARTITION_SIZE_ROCM = 256

if current_platform.is_rocm():
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
)
VLLM_USE_AITER_TRITON_ROPE = envs.VLLM_USE_AITER_TRITON_ROPE
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
VLLM_USE_AITER_TRITON_ROPE = False

if current_platform.is_rocm():
import aiter

from vllm.triton_utils import tl, triton
@@ -209,8 +223,6 @@ def flash_attn_varlen_func_fake(
flash_attn_varlen_func_fake,
dispatch_key=current_platform.dispatch_key)

logger = init_logger(__name__)


@dataclass
class AiterFlashAttentionMetadata:
@@ -420,6 +432,8 @@ def __init__(
if self.sinks is not None:
raise NotImplementedError("Sinks are not supported for ROCM AITER")

self.fp8_dtype = current_platform.fp8_dtype()

def forward(
self,
layer: torch.nn.Module,
@@ -430,6 +444,7 @@ def forward(
attn_metadata: AiterFlashAttentionMetadata,
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
positions: torch.Tensor = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with AiterFlashAttention.
@@ -469,24 +484,70 @@ def forward(

num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
if positions is not None and query.shape[0] <= 256:
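# Fused path: a single AITER Triton kernel (fused_qk_rope_reshape_and_cache)
# applies RoPE to q/k, writes k/v into the KV cache, and zero-fills the
# output buffer (output_zeros=True). Taken only when rotary positions are
# provided and there are at most 256 query tokens.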
assert (
self.kv_sharing_target_layer_name is None
), "self.kv_sharing_target_layer_name cannot be None"
assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}"
cos_sin_cache = self.rotary_emb.cos_sin_cache
is_neox = self.rotary_emb.is_neox_style
cos, sin = cos_sin_cache.chunk(2, dim=-1)
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache_og_dtype = key_cache.dtype
value_cache_og_dtype = value_cache.dtype
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
query, key, key_cache, value_cache, output = (
fused_qk_rope_reshape_and_cache(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
positions,
cos,
sin,
layer._k_scale,
layer._v_scale,
is_neox,
flash_layout=True,
apply_scale=is_fp8_kv_cache,
offs=None,
q_out=query,
k_out=key,
output_zeros=True,
zeros_out=output,
)
)
if is_fp8_kv_cache:
key_cache = key_cache.view(key_cache_og_dtype)
value_cache = value_cache.view(value_cache_og_dtype)
else:
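# Unfused fallback: apply RoPE first (AITER Triton rope or the layer's default
# rotary embedding), then write k/v into the cache with reshape_and_cache_flash.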
if positions is not None:
if VLLM_USE_AITER_TRITON_ROPE:
query, key = self.rotary_emb.forward_cuda(positions, query, key)
else:
query, key = self.rotary_emb(positions, query, key)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)

if self.kv_cache_dtype.startswith("fp8"):
if current_platform.is_fp8_fnuz():
2 changes: 0 additions & 2 deletions vllm/v1/attention/backends/triton_attn.py
@@ -295,8 +295,6 @@ def forward(
output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None,
positions: torch.Tensor = None,
cos_sin_cache: torch.Tensor = None,
is_neox: bool = False,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.