
Commit 506f798

Add some comments and remove unneeded condition
1 parent a29eb07 commit 506f798

2 files changed: +6 -3 lines changed

2 files changed

+6
-3
lines changed

vllm/attention/layer.py

Lines changed: 2 additions & 1 deletion
@@ -532,7 +532,7 @@ def unified_attention_with_output(
     # Not all layers can use RoPE fusing, so check that they were given all
     # needed inputs along with the environment variable to enable this.
     if (
-        VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
+        [replacement line not captured in the extracted diff]
         and hasattr(self.impl, "rotary_emb")
         and self.impl.rotary_emb is not None
         and positions is not None
@@ -542,6 +542,7 @@ def unified_attention_with_output(
             or isinstance(self.impl, AiterMLAImpl)
         )
     ):
+        assert VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE, f"Only expecting rotary_emb and positions when VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is True."
         # fusing RoPE with flushing kv_cache operation
         self.impl.forward(self,
                           query,
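The layer.py change moves the environment-flag test out of the hot-path condition and into an assert inside the branch: the model only wires up rotary_emb and passes positions when the flag is enabled (see the llama4.py change below), so a fused-RoPE request with the flag off now fails loudly instead of silently falling through to the unfused path. Below is a minimal sketch of that guard pattern, assuming a hypothetical impl object, flag parsing, and forward_fused/forward calls; it is not the vLLM implementation.

import os

# Hypothetical flag parsing; only the environment variable name is taken from
# the diff, the parsing itself is illustrative.
FUSED_ROPE_FLAG = os.environ.get(
    "VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "0") == "1"


def run_attention(impl, query, positions=None):
    # The fused path is selected from the inputs this layer was given,
    # not from the flag itself.
    if getattr(impl, "rotary_emb", None) is not None and positions is not None:
        # Receiving rotary_emb and positions with the flag off would mean the
        # model wired up fused RoPE unexpectedly, so fail loudly here.
        assert FUSED_ROPE_FLAG, (
            "rotary_emb and positions are only expected when "
            "VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is True")
        return impl.forward_fused(query, positions)  # hypothetical fused call
    return impl.forward(query)  # hypothetical unfused call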

vllm/model_executor/models/llama4.py

Lines changed: 4 additions & 2 deletions
@@ -209,10 +209,12 @@ def __init__(self,
         extra_args = {}
         if use_chunked_local_attn:
             extra_args["attention_chunk_size"] = config.attention_chunk_size
+        # Use the rotary_emb in attention only when it's supported
         self.use_fused_rope = (
             VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
             and self.rotary_emb is not None
             and self.qk_norm is None
+            and not self.attn_temperature_tuning
         )
         if self.use_fused_rope:
             extra_args["rotary_emb"] = self.rotary_emb
@@ -240,10 +242,10 @@ def forward(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
-        # For limited cases that match Llama3's behavior, use fused RoPE
+        # rotary_emb is fused into self.attn in this case
         if self.use_fused_rope:
             assert not (
-                self.attn_temperature_tuning and self.nope
+                self.attn_temperature_tuning
             ), f"{self.attn_temperature_tuning=} and {self.nope=} must be False with {VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE=}"
             attn_output = self.attn(q, k, v, positions=positions)
             output, _ = self.o_proj(attn_output)
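The llama4.py change tightens when rotary_emb is handed to the attention layer at all: fused RoPE is requested only when the environment flag is on, a rotary embedding exists, there is no QK-norm, and, newly, attention temperature tuning is disabled. Below is a small self-contained sketch of that gating, using hypothetical placeholder values in place of the real module attributes.

def should_use_fused_rope(flag_enabled, rotary_emb, qk_norm,
                          attn_temperature_tuning):
    # Mirrors the condition built in __init__: env flag on, rotary embedding
    # present, no QK-norm, and temperature tuning off (the newly added clause).
    return (
        flag_enabled
        and rotary_emb is not None
        and qk_norm is None
        and not attn_temperature_tuning
    )


extra_args = {}
rotary_emb = object()  # placeholder for a real rotary embedding module
if should_use_fused_rope(True, rotary_emb, None, False):
    # The attention backend applies RoPE itself, so it needs the module.
    extra_args["rotary_emb"] = rotary_emb
print(extra_args)

With not self.attn_temperature_tuning folded into use_fused_rope, the forward() path only has to assert that attn_temperature_tuning is off before calling self.attn(q, k, v, positions=positions).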
