  2 |   2 |  # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
  3 |   3 |  import ast
  4 |   4 |  from dataclasses import replace
  5 |     | -from typing import Optional
    |   5 | +from importlib.util import find_spec
    |   6 | +from typing import Optional, Protocol
  6 |   7 |  
  7 |   8 |  import numpy as np
  8 |   9 |  import torch
 20 |  21 |  from vllm.platforms import current_platform
 21 |  22 |  from vllm.utils import is_pin_memory_available
 22 |  23 |  from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 23 |     | -from vllm.v1.attention.backends.rocm_aiter_fa import (
 24 |     | -    AiterFlashAttentionMetadata)
 25 |  24 |  from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
 26 |  25 |                                                     TreeAttentionMetadataBuilder)
 27 |  26 |  from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
 34 |  33 |  PADDING_SLOT_ID = -1
 35 |  34 |  
 36 |  35 |  
    |  36 | +class EagleAttentionMetadata(Protocol):
    |  37 | +    # Required attributes
    |  38 | +    num_actual_tokens: int
    |  39 | +    max_query_len: int
    |  40 | +    query_start_loc: torch.Tensor
    |  41 | +    max_seq_len: int
    |  42 | +    seq_lens: torch.Tensor
    |  43 | +    block_table: torch.Tensor
    |  44 | +    slot_mapping: torch.Tensor
    |  45 | +
    |  46 | +
 37 |  47 |  class EagleProposer:
 38 |  48 |  
 39 |  49 |      def __init__(
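For readers unfamiliar with `typing.Protocol`: the new `EagleAttentionMetadata` protocol describes, structurally, the attributes every supported attention metadata class exposes, so later code can be typed against that shape without importing each backend. A minimal standalone sketch of the idea; the `DummyMetadata` class, attribute subset, and tensor values below are illustrative, not from vLLM:

```python
from typing import Protocol

import torch


class AttentionMetadataLike(Protocol):
    """Structural type: any object with these attributes matches."""
    num_actual_tokens: int
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor


class DummyMetadata:
    """Not a subclass of the protocol -- it matches by shape alone."""

    def __init__(self) -> None:
        self.num_actual_tokens = 4
        self.query_start_loc = torch.tensor([0, 2, 4])
        self.slot_mapping = torch.tensor([7, 8, 9, 10])


def first_token_slots(md: AttentionMetadataLike) -> torch.Tensor:
    # A type checker accepts DummyMetadata here because it provides the
    # attributes the protocol declares; no inheritance is required.
    return md.slot_mapping[md.query_start_loc[:-1]]


print(first_token_slots(DummyMetadata()))  # tensor([7, 9])
```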
@@ -97,6 +107,20 @@ def __init__(
 97 | 107 |              dtype=self.dtype,
 98 | 108 |              device=device)
 99 | 109 |  
    | 110 | +        # Determine allowed attention backends once during initialization.
    | 111 | +        self.allowed_attn_types: tuple[type[EagleAttentionMetadata], ...]
    | 112 | +        if current_platform.is_rocm():
    | 113 | +            rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
    | 114 | +            # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
    | 115 | +            if find_spec("vllm.v1.attention.backends.rocm_aiter_fa"):
    | 116 | +                from vllm.v1.attention.backends.rocm_aiter_fa import (
    | 117 | +                    AiterFlashAttentionMetadata)
    | 118 | +                rocm_types.append(AiterFlashAttentionMetadata)
    | 119 | +            self.allowed_attn_types = tuple(rocm_types)
    | 120 | +        else:
    | 121 | +            self.allowed_attn_types = (FlashAttentionMetadata,
    | 122 | +                                       TreeAttentionMetadata)
    | 123 | +
100 | 124 |          # Parse the speculative token tree.
101 | 125 |          spec_token_tree = self.speculative_config.speculative_token_tree
102 | 126 |          self.tree_choices: list[tuple[int,
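The `find_spec` guard is what makes the ROCm AITER backend optional: `importlib.util.find_spec` returns `None` when the module cannot be located, so the top-level import of `rocm_aiter_fa` is replaced by an import that only runs when the backend is actually present. A hedged sketch of the same pattern; `some_optional_backend` is a placeholder module name, not a real package:

```python
from importlib.util import find_spec

# Backends that are always importable in this sketch.
allowed_names = ["triton", "flash"]

# Only import the optional backend if the module can actually be found.
# "some_optional_backend" is a made-up name used for illustration.
if find_spec("some_optional_backend") is not None:
    from some_optional_backend import ExtraBackend  # type: ignore
    allowed_names.append(ExtraBackend.__name__)

print(tuple(allowed_names))  # ('triton', 'flash') when the module is absent
```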
@@ -165,7 +189,7 @@ def propose(
165 | 189 |          for layer_name in self.attn_layer_names:
166 | 190 |              per_layer_attn_metadata[layer_name] = attn_metadata
167 | 191 |          if self.use_cuda_graph and \
168 |     | -                num_tokens <= self.cudagraph_batch_sizes[-1]:
    | 192 | +                num_tokens <= self.cudagraph_batch_sizes[-1]:
169 | 193 |              num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
170 | 194 |          else:
171 | 195 |              num_input_tokens = num_tokens
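The comparison against `self.cudagraph_batch_sizes[-1]` only makes sense if the list of captured CUDA graph sizes is kept sorted ascending: a batch can be padded up to the nearest captured size only when one at least as large exists, otherwise execution falls back to eager mode. A small sketch of that padding decision, assuming `pad_for_cudagraph` simply rounds up to the next captured size; the helper and the size list below are illustrative, not vLLM's implementation:

```python
from bisect import bisect_left

# Captured CUDA graph batch sizes, sorted ascending (illustrative values).
cudagraph_batch_sizes = [1, 2, 4, 8, 16, 32]


def pad_for_cudagraph_sketch(num_tokens: int) -> int:
    """Round num_tokens up to the nearest captured size, if any."""
    if num_tokens <= cudagraph_batch_sizes[-1]:
        return cudagraph_batch_sizes[bisect_left(cudagraph_batch_sizes,
                                                 num_tokens)]
    # Too large for any captured graph: run eagerly, unpadded.
    return num_tokens


print(pad_for_cudagraph_sketch(5))   # 8  -> replayed with a padded graph
print(pad_for_cudagraph_sketch(33))  # 33 -> falls back to eager execution
```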
@@ -225,25 +249,13 @@ def propose(
225 | 249 |          # TODO: Currently, MTP module released by deepseek only has
226 | 250 |          # one layer. Adapt this code to support multiple layers once
227 | 251 |          # there's a multi-layer MTP module.
228 |     | -
229 |     | -        # On ROCm, both AiterFlashAttention and TritonAttention
230 |     | -        # support multi-token eagle spec decode.
231 |     | -        if current_platform.is_rocm():
232 |     | -            assert isinstance(
233 |     | -                attn_metadata,
234 |     | -                (TritonAttentionMetadata, AiterFlashAttentionMetadata,
235 |     | -                 FlashAttentionMetadata))
236 |     | -        else:
237 |     | -            # Currently, only FlashAttention supports multi-token eagle spec
238 |     | -            # decode. This is because the code below makes assumptions about
239 |     | -            # attn_metadata attributes available.
240 |     | -            assert isinstance(attn_metadata, FlashAttentionMetadata)
    | 252 | +        assert isinstance(attn_metadata, self.allowed_attn_types)
241 | 253 |  
242 | 254 |          # Generate the remaining draft tokens.
243 | 255 |          draft_token_ids_list = [draft_token_ids]
244 | 256 |  
245 | 257 |          if self.use_cuda_graph and \
246 |     | -                batch_size <= self.cudagraph_batch_sizes[-1]:
    | 258 | +                batch_size <= self.cudagraph_batch_sizes[-1]:
247 | 259 |              input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
248 | 260 |          else:
249 | 261 |              input_batch_size = batch_size
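With `allowed_attn_types` computed once in `__init__`, the per-call branching on `current_platform.is_rocm()` collapses into a single `isinstance` check; the assert documents that the drafting code that follows only reads attributes every allowed metadata class provides. A toy illustration of the same pattern; the metadata classes here are stand-ins, not vLLM's:

```python
from dataclasses import dataclass


@dataclass
class FlashMeta:
    num_actual_tokens: int


@dataclass
class TritonMeta:
    num_actual_tokens: int


@dataclass
class UnsupportedMeta:
    pass


# Decided once at startup instead of on every call.
allowed_attn_types = (FlashMeta, TritonMeta)


def draft(attn_metadata) -> int:
    # Fail fast if a backend sneaks in whose metadata lacks the
    # attributes used below.
    assert isinstance(attn_metadata, allowed_attn_types)
    return attn_metadata.num_actual_tokens


print(draft(FlashMeta(num_actual_tokens=3)))  # 3
# draft(UnsupportedMeta()) would raise AssertionError.
```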
@@ -449,7 +461,7 @@ def propose_tree(
449 | 461 |              num_tokens, -1)
450 | 462 |  
451 | 463 |          if self.use_cuda_graph and \
452 |     | -                num_tokens <= self.cudagraph_batch_sizes[-1]:
    | 464 | +                num_tokens <= self.cudagraph_batch_sizes[-1]:
453 | 465 |              num_input_tokens = self.vllm_config.pad_for_cudagraph(
454 | 466 |                  num_tokens)
455 | 467 |          else:
@@ -508,19 +520,19 @@ def prepare_inputs(
508 | 520 |          """
509 | 521 |          # E.g.
510 | 522 |          # common_attn_metadata.query_start_loc{_cpu}:
511 |     | -        # [0, q1, q1 + q2, q1 + q2 + q3]
    | 523 | +        # [0, q1, q1 + q2, q1 + q2 + q3]
512 | 524 |          # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
513 | 525 |          # num_rejected_tokens: [n1, n2, n3]
514 | 526 |          # This function computes the intermediate values:
515 | 527 |          # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
516 | 528 |          # And returns:
517 | 529 |          # common_attn_metadata.query_start_loc{_cpu}:
518 |     | -        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
    | 530 | +        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
519 | 531 |          # common_attn_metadata.seq_lens{_cpu}:
520 |     | -        # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
    | 532 | +        # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
521 | 533 |          # token_indices: [0, 1, ..., q1 - n1 - 1,
522 |     | -        #                q1, q1 + 1, ..., q1 + q2 - n2 - 1,
523 |     | -        #                q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
    | 534 | +        #                q1, q1 + 1, ..., q1 + q2 - n2 - 1,
    | 535 | +        #                q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
524 | 536 |  
525 | 537 |          device = common_attn_metadata.query_start_loc.device
526 | 538 |          query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
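The arithmetic documented in this hunk is easier to follow with concrete numbers. The standalone sketch below reimplements only the documented formulas with made-up query lengths `q`, sequence lengths `s`, and rejected-token counts `n` (it is not vLLM's vectorized code, which appears in the next hunk):

```python
import numpy as np

# Illustrative values (not from vLLM): per-request query lengths, sequence
# lengths, and number of rejected speculative tokens.
q = np.array([4, 3, 2])
s = np.array([10, 8, 6])
n = np.array([1, 0, 1])

query_start_loc = np.concatenate([[0], np.cumsum(q)])   # [0, 4, 7, 9]

# Values the comments say prepare_inputs should produce.
new_num_tokens_per_req = q - n                           # [3, 3, 1]
new_query_start_loc = np.concatenate(
    [[0], np.cumsum(new_num_tokens_per_req)])            # [0, 3, 6, 7]
new_seq_lens = s - n + 1                                 # [10, 9, 6]

# token_indices keeps, for each request, its first (q_i - n_i) tokens,
# offset by where that request started in the old flattened layout.
token_indices = np.concatenate([
    np.arange(start, start + keep)
    for start, keep in zip(query_start_loc[:-1], new_num_tokens_per_req)
])
print(new_query_start_loc)  # [0 3 6 7]
print(new_seq_lens)         # [10  9  6]
print(token_indices)        # [0 1 2 4 5 6 7]
```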
@@ -564,9 +576,9 @@ def prepare_inputs(
564 | 576 |          old_query_start_locs_expanded = np.repeat(
565 | 577 |              query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
566 | 578 |          # Final token indices are:
567 |     | -        # [0, 1, // req 1
568 |     | -        # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
569 |     | -        # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
    | 579 | +        # [0, 1, // req 1
    | 580 | +        # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
    | 581 | +        # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
570 | 582 |          token_indices_np = token_offests + old_query_start_locs_expanded
571 | 583 |          token_indices = torch.from_numpy(token_indices_np).to(
572 | 584 |              device, non_blocking=True)
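The `np.repeat` line is the vectorized form of the per-request loop sketched earlier: each request's old start offset is repeated once per kept token, and adding a within-request offset array (the `token_offests` variable, which is built outside this hunk) yields the flattened indices in one shot. A hedged reconstruction of that trick; the construction of the offset array here is an assumption about code not shown in the diff:

```python
import numpy as np

query_start_loc = np.array([0, 4, 7, 9])      # old cumulative layout
new_num_tokens_per_req = np.array([3, 3, 1])  # q - n, as computed earlier

# 0, 1, ..., k_i - 1 for every request, laid out back to back. This is an
# assumed reconstruction of the offset array used in the hunk above.
cu_new = np.cumsum(new_num_tokens_per_req)
token_offsets = np.arange(cu_new[-1]) - np.repeat(
    np.concatenate([[0], cu_new[:-1]]), new_num_tokens_per_req)

# Repeat each request's old start once per kept token, then add offsets.
old_starts_expanded = np.repeat(query_start_loc[:-1], new_num_tokens_per_req)
token_indices = token_offsets + old_starts_expanded
print(token_indices)  # [0 1 2 4 5 6 7]
```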
@@ -616,20 +628,18 @@ def load_model(self, target_model: nn.Module) -> None:
616 | 628 |              target_language_model = target_model
617 | 629 |          # share embed_tokens with the target model if needed
618 | 630 |          if get_pp_group().world_size == 1 \
619 |     | -                and self.model.model.embed_tokens.weight.shape \
620 |     | -                == target_language_model.model.embed_tokens.weight.shape:
    | 631 | +                and self.model.model.embed_tokens.weight.shape \
    | 632 | +                == target_language_model.model.embed_tokens.weight.shape:
621 | 633 |              logger.info(
622 |     | -                "Assuming the EAGLE head shares the same vocab embedding" \
623 |     | -                " with the target model."
624 |     | -            )
    | 634 | +                "Assuming the EAGLE head shares the same vocab embedding"
    | 635 | +                " with the target model.")
625 | 636 |              del self.model.model.embed_tokens
626 | 637 |              self.model.model.embed_tokens = (
627 | 638 |                  target_language_model.model.embed_tokens)
628 | 639 |          else:
629 | 640 |              logger.info(
630 |     | -                "The EAGLE head's vocab embedding will be loaded separately" \
631 |     | -                " from the target model."
632 |     | -            )
    | 641 | +                "The EAGLE head's vocab embedding will be loaded separately"
    | 642 | +                " from the target model.")
633 | 643 |  
634 | 644 |          # share lm_head with the target model if needed
635 | 645 |          # some model definition do not define lm_head explicitly
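The embedding-sharing branch above saves memory by aliasing the EAGLE head's `embed_tokens` to the target model's module when, and only when, the weight shapes match; otherwise the head keeps its own embedding and loads it from its checkpoint. A toy sketch of that shape-gated sharing; the module sizes and names are illustrative, not vLLM's:

```python
import torch.nn as nn

target_embed = nn.Embedding(1000, 64)  # target model's vocab embedding
draft_embed = nn.Embedding(1000, 64)   # draft head's own embedding

if draft_embed.weight.shape == target_embed.weight.shape:
    # Shapes match: drop the draft copy and alias the target's module,
    # so both models read from the same weight tensor.
    draft_embed = target_embed
    assert draft_embed.weight.data_ptr() == target_embed.weight.data_ptr()
else:
    # Shapes differ (e.g. a different vocab size): keep separate weights
    # and load the draft head's embedding from its own checkpoint.
    pass
```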