
Commit 88f5dc9

JartX authored and tjtanaa committed
[FIXBUG] Allow disabling rocm_aiter_fa backend for ROCm GPUs not compatible with AITER (vllm-project#22795)
Signed-off-by: JartX <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Co-authored-by: tjtanaa <[email protected]>
1 parent 39c0aab commit 88f5dc9

File tree

1 file changed (+45, -35 lines)


vllm/v1/spec_decode/eagle.py

Lines changed: 45 additions & 35 deletions
@@ -2,7 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import ast
 from dataclasses import replace
-from typing import Optional
+from importlib.util import find_spec
+from typing import Optional, Protocol
 
 import numpy as np
 import torch
@@ -20,8 +21,6 @@
 from vllm.platforms import current_platform
 from vllm.utils import is_pin_memory_available
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
-from vllm.v1.attention.backends.rocm_aiter_fa import (
-    AiterFlashAttentionMetadata)
 from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
                                                   TreeAttentionMetadataBuilder)
 from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
@@ -34,6 +33,17 @@
 PADDING_SLOT_ID = -1
 
 
+class EagleAttentionMetadata(Protocol):
+    # Required attributes
+    num_actual_tokens: int
+    max_query_len: int
+    query_start_loc: torch.Tensor
+    max_seq_len: int
+    seq_lens: torch.Tensor
+    block_table: torch.Tensor
+    slot_mapping: torch.Tensor
+
+
 class EagleProposer:
 
     def __init__(
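Note: the new EagleAttentionMetadata Protocol gives the proposer one structural type that every backend's metadata object satisfies simply by exposing the listed attributes, with no inheritance required. Below is a minimal, self-contained sketch of that pattern; MetadataLike and DummyAttentionMetadata are hypothetical stand-ins (not vLLM classes), and the interface is abridged to three attributes. The runtime isinstance checks later in the diff still target the concrete backend classes, since a plain Protocol is not runtime-checkable.

from dataclasses import dataclass, field
from typing import Protocol

import torch


class MetadataLike(Protocol):
    # Abridged structural interface: any object exposing these attributes
    # can be used where MetadataLike is expected.
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor


@dataclass
class DummyAttentionMetadata:  # hypothetical stand-in for a backend class
    num_actual_tokens: int
    max_query_len: int
    query_start_loc: torch.Tensor = field(
        default_factory=lambda: torch.zeros(1, dtype=torch.int32))


def consume(meta: MetadataLike) -> int:
    # Type-checks for any structurally compatible object.
    return meta.num_actual_tokens


print(consume(DummyAttentionMetadata(num_actual_tokens=3, max_query_len=1)))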
@@ -97,6 +107,20 @@ def __init__(
             dtype=self.dtype,
             device=device)
 
+        # Determine allowed attention backends once during initialization.
+        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
+        if current_platform.is_rocm():
+            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
+            # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
+            if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
+                from vllm.v1.attention.backends.rocm_aiter_fa import (
+                    AiterFlashAttentionMetadata)
+                rocm_types.append(AiterFlashAttentionMetadata)
+            self.allowed_attn_types = tuple(rocm_types)
+        else:
+            self.allowed_attn_types = (FlashAttentionMetadata,
+                                       TreeAttentionMetadata)
+
         # Parse the speculative token tree.
         spec_token_tree = self.speculative_config.speculative_token_tree
         self.tree_choices: list[tuple[int,
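This initialization block is the heart of the fix: the optional rocm_aiter_fa backend is imported only when importlib.util.find_spec can locate its module, so ROCm setups where AITER is disabled or unavailable never hit the import. A minimal, self-contained sketch of the same gating pattern follows; the module and class names are placeholders for illustration, not real vLLM modules.

from importlib.util import find_spec


class TritonMeta:  # stand-in for an always-available metadata class
    pass


class FlashMeta:  # stand-in for an always-available metadata class
    pass


allowed: list[type] = [TritonMeta, FlashMeta]

# find_spec() returns None when the module cannot be located, so the
# optional backend is imported (and allowed) only when it is installed.
if find_spec("optional_aiter_backend") is not None:  # placeholder name
    from optional_aiter_backend import AiterMeta  # type: ignore
    allowed.append(AiterMeta)

ALLOWED_ATTN_TYPES: tuple[type, ...] = tuple(allowed)
print(ALLOWED_ATTN_TYPES)  # falls back to the base types when absent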
@@ -165,7 +189,7 @@ def propose(
         for layer_name in self.attn_layer_names:
             per_layer_attn_metadata[layer_name] = attn_metadata
         if self.use_cuda_graph and \
-            num_tokens <= self.cudagraph_batch_sizes[-1]:
+                num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
         else:
             num_input_tokens = num_tokens
@@ -225,25 +249,13 @@ def propose(
         # TODO: Currently, MTP module released by deepseek only has
         # one layer. Adapt this code to support multiple layers once
         # there's a multi-layer MTP module.
-
-        # On ROCm, both AiterFlashAttention and TritonAttention
-        # support multi-token eagle spec decode.
-        if current_platform.is_rocm():
-            assert isinstance(
-                attn_metadata,
-                (TritonAttentionMetadata, AiterFlashAttentionMetadata,
-                 FlashAttentionMetadata))
-        else:
-            # Currently, only FlashAttention supports multi-token eagle spec
-            # decode. This is because the code below makes assumptions about
-            # attn_metadata attributes available.
-            assert isinstance(attn_metadata, FlashAttentionMetadata)
+        assert isinstance(attn_metadata, self.allowed_attn_types)
 
         # Generate the remaining draft tokens.
         draft_token_ids_list = [draft_token_ids]
 
         if self.use_cuda_graph and \
-            batch_size <= self.cudagraph_batch_sizes[-1]:
+                batch_size <= self.cudagraph_batch_sizes[-1]:
             input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
         else:
             input_batch_size = batch_size
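With the allowed metadata types computed once in __init__, the per-call platform branching above collapses to a single assert, because isinstance accepts a tuple of classes. A tiny sketch with stand-in classes (not vLLM types):

class FlashMeta:  # stand-in metadata class for illustration
    pass


class TritonMeta:  # stand-in metadata class for illustration
    pass


ALLOWED_ATTN_TYPES: tuple[type, ...] = (FlashMeta, TritonMeta)

meta = TritonMeta()
# Passes for an instance of any allowed type, regardless of platform.
assert isinstance(meta, ALLOWED_ATTN_TYPES)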
@@ -449,7 +461,7 @@ def propose_tree(
             num_tokens, -1)
 
         if self.use_cuda_graph and \
-            num_tokens <= self.cudagraph_batch_sizes[-1]:
+                num_tokens <= self.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 num_tokens)
         else:
@@ -508,19 +520,19 @@ def prepare_inputs(
         """
         # E.g.
         # common_attn_metadata.query_start_loc{_cpu}:
-        #   [0, q1, q1 + q2, q1 + q2 + q3]
+        #  [0, q1, q1 + q2, q1 + q2 + q3]
         # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
         # num_rejected_tokens: [n1, n2, n3]
         # This function computes the intermediate values:
         # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
         # And returns:
         # common_attn_metadata.query_start_loc{_cpu}:
-        #   [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
+        #  [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
         # common_attn_metadata.seq_lens{_cpu}:
-        #   [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
+        #  [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
         # token_indices: [0, 1, ..., q1 - n1 - 1,
-        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
-        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
+        #                q1, q1 + 1, ..., q1 + q2 - n2 - 1,
+        #                q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
 
         device = common_attn_metadata.query_start_loc.device
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
@@ -564,9 +576,9 @@ def prepare_inputs(
         old_query_start_locs_expanded = np.repeat(
             query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
         # Final token indices are:
-        #    [0, 1,                                   // req 1
-        #     q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
-        #     q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
+        # [0, 1,                                      // req 1
+        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,            // req 2
+        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]     // req 3
         token_indices_np = token_offests + old_query_start_locs_expanded
         token_indices = torch.from_numpy(token_indices_np).to(
             device, non_blocking=True)
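The comment blocks above describe the rejected-token bookkeeping in prepare_inputs. Here is a worked numpy sketch of that index arithmetic with illustrative values (q = per-request query lengths, n = per-request rejected draft tokens); it mirrors the np.repeat / offset-addition pattern in the diff but is not the vLLM implementation itself.

import numpy as np

q = np.array([3, 5, 4])                        # q1, q2, q3
n = np.array([1, 1, 1])                        # n1, n2, n3
old_query_start_loc = np.array([0, 3, 8, 12])  # [0, q1, q1+q2, q1+q2+q3]

new_num_tokens_per_req = q - n                 # [2, 4, 3]
new_query_start_loc = np.zeros(len(q) + 1, dtype=np.int64)
np.cumsum(new_num_tokens_per_req, out=new_query_start_loc[1:])
# new_query_start_loc -> [0, 2, 6, 9]

# Per-token offsets within each shrunken request, then shift each
# request's offsets by where that request started in the old layout.
token_offsets = np.concatenate(
    [np.arange(c) for c in new_num_tokens_per_req])
old_starts_expanded = np.repeat(old_query_start_loc[:-1],
                                new_num_tokens_per_req)
token_indices = token_offsets + old_starts_expanded
print(token_indices)  # [0 1 3 4 5 6 8 9 10]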
@@ -616,20 +628,18 @@ def load_model(self, target_model: nn.Module) -> None:
             target_language_model = target_model
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1 \
-            and self.model.model.embed_tokens.weight.shape \
-                == target_language_model.model.embed_tokens.weight.shape:
+                and self.model.model.embed_tokens.weight.shape \
+                == target_language_model.model.embed_tokens.weight.shape:
             logger.info(
-                "Assuming the EAGLE head shares the same vocab embedding" \
-                " with the target model."
-            )
+                "Assuming the EAGLE head shares the same vocab embedding"
+                " with the target model.")
             del self.model.model.embed_tokens
             self.model.model.embed_tokens = (
                 target_language_model.model.embed_tokens)
         else:
             logger.info(
-                "The EAGLE head's vocab embedding will be loaded separately" \
-                " from the target model."
-            )
+                "The EAGLE head's vocab embedding will be loaded separately"
+                " from the target model.")
 
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
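The load_model hunk only reflows the log messages, but the surrounding logic is worth illustrating: the EAGLE head reuses the target model's vocab embedding only when pipeline parallelism is off and the weight shapes match. A minimal sketch of that shape guard, using plain nn.Embedding stand-ins and hypothetical sizes rather than vLLM model classes:

import torch.nn as nn

# Hypothetical vocabulary and hidden sizes, for illustration only.
target_embed = nn.Embedding(num_embeddings=1000, embedding_dim=64)
draft_embed = nn.Embedding(num_embeddings=1000, embedding_dim=64)

if draft_embed.weight.shape == target_embed.weight.shape:
    # Share the target's embedding module instead of keeping a duplicate.
    draft_embed = target_embed
    print("Assuming the EAGLE head shares the same vocab embedding"
          " with the target model.")
else:
    print("The EAGLE head's vocab embedding will be loaded separately"
          " from the target model.")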
