@@ -7,14 +7,12 @@
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
-from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
 from ..pyexecutor.llm_request import LlmRequest
 from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager
 from ..pyexecutor.sampler import TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
-from .interface import SpecMetadata, get_force_num_accepted_tokens
+from .interface import SpecMetadata, SpecWorkerBase
 from .mtp import MTPSampler
-from .one_model_sampler import sampling_batch_spec_dec_one_model
 from .spec_tree_manager import SpecTreeManager
 
 if TYPE_CHECKING:
@@ -358,15 +356,16 @@ def __init__(self, args: TorchSampler.Args):
         super().__init__(args, nextn=args.max_draft_len)
 
 
-class Eagle3OneModelWorker(nn.Module):
+class Eagle3OneModelWorker(SpecWorkerBase):
 
     def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
         super().__init__()
         self.spec_config = spec_config
-        self.max_draft_len = self.spec_config.max_draft_len
         self.mapping = mapping
-        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
-        self.force_num_accepted_tokens = get_force_num_accepted_tokens()
+
+    @property
+    def max_draft_len(self) -> int:
+        return self.spec_config.max_draft_len
 
     # Skip torch.compile for now since current Torch is not compatible with Triton 3.4
     # @torch.compile(options={"max-autotune": True})
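
The hunk above moves Eagle3OneModelWorker onto the shared SpecWorkerBase and turns max_draft_len into a read-only property, dropping the guided-decoder and force_num_accepted_tokens state the worker used to initialize itself. The diff never shows SpecWorkerBase, so the following is only a minimal sketch of what the base class plausibly provides, reconstructed from the members this diff deletes; the real definition lives in .interface and may differ.

from abc import abstractmethod
from typing import Optional

import torch.nn as nn

from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
# get_force_num_accepted_tokens was previously imported from .interface,
# the same module this base class would live in.


class SpecWorkerBase(nn.Module):
    # Hypothetical reconstruction, not code from this PR.

    def __init__(self):
        super().__init__()
        # Shared state formerly initialized by each concrete worker.
        self.guided_decoder: Optional[CapturableGuidedDecoder] = None
        self.force_num_accepted_tokens = get_force_num_accepted_tokens()

    @property
    @abstractmethod
    def max_draft_len(self) -> int:
        # Concrete workers derive this from their spec_config.
        ...

    def set_guided_decoder(self,
                           guided_decoder: CapturableGuidedDecoder) -> bool:
        # Mirrors the method removed from Eagle3OneModelWorker in the
        # final hunk of this diff.
        self.guided_decoder = guided_decoder
        return True
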
@@ -494,40 +493,6 @@ def forward(self, input_ids, position_ids, hidden_states, logits,
             'next_new_tokens': next_new_tokens,
         }
 
-    def _sample_tokens_for_batch(
-        self,
-        logits: torch.Tensor,
-        spec_metadata: Eagle3OneModelSpecMetadata,
-        num_contexts: int,
-        batch_size: int,
-    ) -> torch.Tensor:
-        """
-        Sample tokens from logits using per-request sampling parameters.
-        Supports both greedy and non-greedy sampling.
-
-        Args:
-            logits: [num_tokens, vocab_size] - Logits to sample from
-            spec_metadata: Metadata containing sampling parameters
-            batch_size: Number of requests in the batch
-
-        Returns:
-            sampled_tokens: [num_tokens] - Sampled token ids
-        """
-        if spec_metadata.allow_advanced_sampling:
-            num_gens = batch_size - num_contexts
-            num_tokens = num_contexts + num_gens * (self.max_draft_len + 1)
-
-            temperatures = spec_metadata.temperatures[:num_tokens]
-            top_ks = spec_metadata.top_ks[:num_tokens]
-            top_ps = spec_metadata.top_ps[:num_tokens]
-
-            sampled_tokens = sampling_batch_spec_dec_one_model(
-                logits, temperatures, top_ks, top_ps)
-        else:
-            sampled_tokens = torch.argmax(logits, dim=-1)
-
-        return sampled_tokens
-
     def sample_and_accept_draft_tokens(
         self,
         logits: torch.Tensor,
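
The hunk above deletes _sample_tokens_for_batch from this worker; given the new SpecWorkerBase parent, the per-request sampling path presumably now lives in the shared base rather than being duplicated per worker. For reference, the deleted helper's token-count arithmetic works as follows (a worked example with assumed values, not code from the PR):

# Context requests contribute one logit row each; each generation request
# contributes max_draft_len + 1 rows (one target token plus the drafts).
num_contexts = 2
batch_size = 4
max_draft_len = 3

num_gens = batch_size - num_contexts                         # 2
num_tokens = num_contexts + num_gens * (max_draft_len + 1)   # 2 + 2 * 4 = 10
# temperatures/top_ks/top_ps are per-token buffers, so the helper sliced
# each of them to the first num_tokens entries before sampling.
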
@@ -578,7 +543,7 @@ def draft_decoder(
         draft_model: nn.Module,
     ):
         '''
-        Sampling draft tokens.
+        Sampling draft tokens with support for non-greedy sampling.
 
         Args:
             logits: torch.Tensor
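
The docstring change above advertises non-greedy sampling in draft_decoder, but the mechanism itself is not shown in this diff (note that the sampling_batch_spec_dec_one_model import is removed in the first hunk). For readers unfamiliar with the mechanics, here is a generic, self-contained sketch of temperature/top-k/top-p sampling over a batch of logits; it is illustrative only, not the sampler this PR wires in.

import torch


def sample_non_greedy(logits: torch.Tensor, temperature: float, top_k: int,
                      top_p: float) -> torch.Tensor:
    # logits: [num_tokens, vocab_size]; returns [num_tokens] sampled ids.
    logits = logits / max(temperature, 1e-6)
    if top_k > 0:
        # Keep only the k largest logits per row.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        # Drop the low-probability tail once cumulative mass exceeds top_p,
        # always keeping at least the most likely token.
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        remove = probs.cumsum(dim=-1) - probs > top_p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf")).scatter(
            -1, sorted_idx, sorted_logits)
    return torch.multinomial(torch.softmax(logits, dim=-1),
                             num_samples=1).squeeze(-1)
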
@@ -649,8 +614,3 @@ def prepare_1st_drafter_inputs(
             "attn_metadata": attn_metadata,
             "spec_metadata": spec_metadata,
         }
-
-    def set_guided_decoder(self,
-                           guided_decoder: CapturableGuidedDecoder) -> bool:
-        self.guided_decoder = guided_decoder
-        return True