Commit a29eb07

Add limited Llama4 support for VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
This path is enabled by setting VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=1 (on ROCm, together with VLLM_ROCM_USE_AITER=1).
1 parent 2b4cb8a commit a29eb07
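
As a quick usage sketch (not part of the commit), the flags could be exported before starting vLLM on a ROCm build with AITER available; the model id below is only an example:

# Usage sketch, not part of this commit: enable the fused RoPE +
# zeroed KV cache path on a ROCm build of vLLM with AITER available.
# The env vars are set before constructing the engine so the
# module-level flag in llama4.py sees them when the model is loaded.
import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-4-Scout-17B-16E-Instruct")  # example model id
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)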

File tree

3 files changed: +45 -7 lines changed

vllm/attention/layer.py
vllm/attention/layers/chunked_local_attention.py
vllm/model_executor/models/llama4.py

vllm/attention/layer.py

Lines changed: 13 additions & 2 deletions
@@ -529,9 +529,20 @@ def unified_attention_with_output(
     from vllm.v1.attention.backends.triton_attn import TritonAttentionImpl
     from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionImpl
     from vllm.v1.attention.backends.mla.rocm_aiter_mla import AiterMLAImpl
-    if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and (isinstance(self.impl, TritonAttentionImpl) or isinstance(self.impl, AiterFlashAttentionImpl) or isinstance(self.impl, AiterMLAImpl)):
+    # Not all layers can use RoPE fusing, so check that they were given all
+    # needed inputs along with the environment variable to enable this.
+    if (
+        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+        and hasattr(self.impl, "rotary_emb")
+        and self.impl.rotary_emb is not None
+        and positions is not None
+        and (
+            isinstance(self.impl, TritonAttentionImpl)
+            or isinstance(self.impl, AiterFlashAttentionImpl)
+            or isinstance(self.impl, AiterMLAImpl)
+        )
+    ):
         # fusing RoPE with flushing kv_cache operation
-        assert hasattr(self.impl, "rotary_emb") and self.impl.rotary_emb is not None and positions is not None, f"rotary_emb not found in {self.impl=} and positions cannot be None"
         self.impl.forward(self,
                           query,
                           key,
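
This change replaces the hard assertion with a precondition check, so layers without a rotary_emb (or calls without positions) fall back to the regular path instead of failing. A minimal stand-alone sketch of that guard-then-fallback dispatch; class and attribute names here are illustrative, not the vLLM API:

# Illustrative sketch of the dispatch pattern above; not vLLM code.
class GuardedAttention:
    def __init__(self, impl, fuse_enabled: bool):
        self.impl = impl                  # backend implementation object
        self.fuse_enabled = fuse_enabled  # env-flag-derived switch

    def forward(self, q, k, v, positions=None):
        can_fuse = (
            self.fuse_enabled
            and getattr(self.impl, "rotary_emb", None) is not None
            and positions is not None
        )
        if can_fuse:
            # Fused path: the backend applies RoPE while writing the KV cache.
            return self.impl.forward(q, k, v, positions=positions)
        # Fallback: RoPE is assumed to have been applied by the caller.
        return self.impl.forward(q, k, v)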

vllm/attention/layers/chunked_local_attention.py

Lines changed: 4 additions & 2 deletions
@@ -58,7 +58,8 @@ def __init__(self,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  kv_sharing_target_layer_name: Optional[str] = None,
-                 prefix: str = ""):
+                 prefix: str = "",
+                 **kwargs):
         dtype = torch.get_default_dtype()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
@@ -88,4 +89,5 @@ def __init__(self,
             quant_config=quant_config,
             prefix=prefix,
             kv_sharing_target_layer_name=kv_sharing_target_layer_name,
-            attn_backend=attn_backend)
+            attn_backend=attn_backend,
+            **kwargs)
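
ChunkedLocalAttention now accepts and forwards arbitrary keyword arguments (here, rotary_emb) to the wrapped Attention layer, so Llama4 can build a single extra_args dict and pass it to either attention class. A tiny sketch of that pass-through pattern with made-up class names:

# Illustrative **kwargs pass-through; class names are made up.
class BaseAttention:
    def __init__(self, num_heads: int, rotary_emb=None):
        self.num_heads = num_heads
        self.rotary_emb = rotary_emb  # optional, only set on the fused path

class ChunkedWrapper(BaseAttention):
    def __init__(self, num_heads: int, attention_chunk_size: int = 0, **kwargs):
        # Unrecognized keyword arguments (e.g. rotary_emb) are forwarded
        # unchanged, so callers can treat both classes interchangeably.
        super().__init__(num_heads, **kwargs)
        self.attention_chunk_size = attention_chunk_size

attn = ChunkedWrapper(8, attention_chunk_size=512, rotary_emb=object())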

vllm/model_executor/models/llama4.py

Lines changed: 28 additions & 3 deletions
@@ -24,6 +24,7 @@
 from torch import nn
 from transformers import Llama4TextConfig
 
+import vllm.envs as envs
 from vllm.attention import Attention
 from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.decorators import support_torch_compile
@@ -38,12 +39,19 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.platforms import current_platform
 
 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
                     is_pp_missing_parameter)
 
 
+VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
+    current_platform.is_rocm()
+    and envs.VLLM_ROCM_USE_AITER
+    and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+)
+
 class Llama4MoE(nn.Module):
 
     @staticmethod
@@ -198,6 +206,16 @@ def __init__(self,
         use_chunked_local_attn = not self.nope and config.attention_chunk_size
         attn_cls = (ChunkedLocalAttention
                     if use_chunked_local_attn else Attention)
+        extra_args = {}
+        if use_chunked_local_attn:
+            extra_args["attention_chunk_size"] = config.attention_chunk_size
+        self.use_fused_rope = (
+            VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+            and self.rotary_emb is not None
+            and self.qk_norm is None
+        )
+        if self.use_fused_rope:
+            extra_args["rotary_emb"] = self.rotary_emb
         self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
@@ -206,9 +224,7 @@ def __init__(self,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            **({
-                "attention_chunk_size": config.attention_chunk_size
-            } if use_chunked_local_attn else {}))
+            **extra_args)
 
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
@@ -224,6 +240,15 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
+        # For limited cases that match Llama3's behavior, use fused RoPE
+        if self.use_fused_rope:
+            assert not (
+                self.attn_temperature_tuning and self.nope
+            ), f"{self.attn_temperature_tuning=} and {self.nope=} must be False with {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}"
+            attn_output = self.attn(q, k, v, positions=positions)
+            output, _ = self.o_proj(attn_output)
+            return output
+
         if self.rotary_emb is not None:
             q, k = self.rotary_emb(positions, q, k)
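
Putting the three files together, the fused path only activates for a Llama4 attention layer under a narrow set of conditions. A hypothetical helper that restates them in one place (this function is not part of the commit):

# Hypothetical summary of the checks spread across llama4.py and layer.py;
# not part of the commit, shown only to collect the conditions in one place.
def llama4_layer_uses_fused_rope(layer, is_rocm: bool, use_aiter: bool,
                                 fused_rope_flag: bool) -> bool:
    return (
        is_rocm and use_aiter and fused_rope_flag               # platform + env flags
        and layer.rotary_emb is not None                        # layer applies RoPE
        and layer.qk_norm is None                               # QK-norm not supported
        and not (layer.attn_temperature_tuning and layer.nope)  # asserted in forward()
    )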
