 from torch import nn
 from transformers import Llama4TextConfig
 
+import vllm.envs as envs
 from vllm.attention import Attention
 from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.decorators import support_torch_compile
@@ -38,12 +39,19 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.platforms import current_platform
 
 from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
                     is_pp_missing_parameter)
 
 
+VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
+    current_platform.is_rocm()
+    and envs.VLLM_ROCM_USE_AITER
+    and envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+)
+
 class Llama4MoE(nn.Module):
 
     @staticmethod
@@ -198,6 +206,16 @@ def __init__(self,
         use_chunked_local_attn = not self.nope and config.attention_chunk_size
         attn_cls = (ChunkedLocalAttention
                     if use_chunked_local_attn else Attention)
+        extra_args = {}
+        if use_chunked_local_attn:
+            extra_args["attention_chunk_size"] = config.attention_chunk_size
+        self.use_fused_rope = (
+            VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+            and self.rotary_emb is not None
+            and self.qk_norm is None
+        )
+        if self.use_fused_rope:
+            extra_args["rotary_emb"] = self.rotary_emb
         self.attn = attn_cls(
             self.num_heads,
             self.head_dim,
@@ -206,9 +224,7 @@ def __init__(self,
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
-            **({
-                "attention_chunk_size": config.attention_chunk_size
-            } if use_chunked_local_attn else {}))
+            **extra_args)
 
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
@@ -224,6 +240,15 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
+        # For limited cases that match Llama3's behavior, use fused RoPE
+        if self.use_fused_rope:
+            assert not (
+                self.attn_temperature_tuning and self.nope
+            ), f"{self.attn_temperature_tuning=} and {self.nope=} must be False with {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}"
+            attn_output = self.attn(q, k, v, positions=positions)
+            output, _ = self.o_proj(attn_output)
+            return output
+
         if self.rotary_emb is not None:
             q, k = self.rotary_emb(positions, q, k)
 
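For reference, a minimal, self-contained sketch of the pattern this patch introduces: a module-level gate combines a ROCm platform check with two environment toggles, the constructor forwards rotary_emb to the attention backend only when that gate and the per-layer conditions hold, and forward() then takes the fused path instead of applying RoPE before calling attention. FakeAttention, FakeLayer, and the _env_flag/_is_rocm helpers below are hypothetical stand-ins, not vLLM APIs.

import os


def _env_flag(name: str) -> bool:
    # Hypothetical stand-in for vllm.envs: treat "1"/"true" as enabled.
    return os.environ.get(name, "0").lower() in ("1", "true")


def _is_rocm() -> bool:
    # Hypothetical stand-in for current_platform.is_rocm().
    return _env_flag("FAKE_ROCM")


# Module-level gate, mirroring the flag added at the top of the file.
USE_FUSED_ROPE_PATH = (
    _is_rocm()
    and _env_flag("VLLM_ROCM_USE_AITER")
    and _env_flag("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE")
)


class FakeAttention:
    """Toy backend: applies RoPE internally only if given a rotary_emb."""

    def __init__(self, rotary_emb=None, **extra_args):
        self.rotary_emb = rotary_emb
        self.extra_args = extra_args

    def __call__(self, q, k, v, positions=None):
        if self.rotary_emb is not None and positions is not None:
            q, k = self.rotary_emb(positions, q, k)
        return q  # placeholder for the real attention output


class FakeLayer:
    """Toy layer showing the constructor-time and forward-time split."""

    def __init__(self, rotary_emb, qk_norm=None,
                 use_chunked_local_attn=False, attention_chunk_size=None):
        self.rotary_emb = rotary_emb
        self.qk_norm = qk_norm
        # Same shape as the diff: build the kwargs dict once, and pass
        # rotary_emb through only when the fused path is usable here.
        extra_args = {}
        if use_chunked_local_attn:
            extra_args["attention_chunk_size"] = attention_chunk_size
        self.use_fused_rope = (
            USE_FUSED_ROPE_PATH
            and rotary_emb is not None
            and qk_norm is None
        )
        if self.use_fused_rope:
            extra_args["rotary_emb"] = rotary_emb
        self.attn = FakeAttention(**extra_args)

    def forward(self, positions, q, k, v):
        if self.use_fused_rope:
            # Fused path: the backend rotates q/k itself.
            return self.attn(q, k, v, positions=positions)
        # Fallback: apply RoPE here, then call the backend as before.
        if self.rotary_emb is not None:
            q, k = self.rotary_emb(positions, q, k)
        return self.attn(q, k, v)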