|
20 | 20 | from torch._prims_common import DeviceLikeType |
21 | 21 |
|
22 | 22 | from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures |
23 | | -from tensorrt_llm._torch.pyexecutor._util import _create_kv_cache_manager, get_kv_cache_manager_cls |
| 23 | +from tensorrt_llm._torch.pyexecutor._util import ( |
| 24 | + _create_kv_cache_manager, |
| 25 | + get_decoding_mode, |
| 26 | + get_kv_cache_manager_cls, |
| 27 | +) |
24 | 28 | from tensorrt_llm._torch.pyexecutor.guided_decoder import GuidedDecoder |
25 | 29 | from tensorrt_llm._torch.pyexecutor.llm_request import get_draft_token_length |
26 | 30 | from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config |
|
30 | 34 | from tensorrt_llm.llmapi.llm_args import ( |
31 | 35 | ContextChunkingPolicy, |
32 | 36 | LoadFormat, |
| 37 | + SamplerType, |
33 | 38 | SpeculativeConfig, |
34 | 39 | TorchLlmArgs, |
35 | 40 | ) |
|
42 | 47 | from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine |
43 | 48 | from ...pyexecutor.py_executor import PyExecutor |
44 | 49 | from ...pyexecutor.resource_manager import KVCacheManager, ResourceManager, ResourceManagerType |
45 | | -from ...pyexecutor.sampler import TorchSampler |
| 50 | +from ...pyexecutor.sampler import TorchSampler, TRTLLMSampler |
46 | 51 | from ...pyexecutor.scheduler import ( |
47 | 52 | BindCapacityScheduler, |
48 | 53 | BindMicroBatchScheduler, |
@@ -283,9 +288,9 @@ def __init__( |
283 | 288 | self.llm_args.batch_wait_timeout_iters = 0 |
284 | 289 | self.llm_args.batch_wait_max_tokens_ratio = 0.0 |
285 | 290 | self.llm_args.max_num_tokens = seq_info.max_num_tokens |
| 291 | + self.llm_args.max_seq_len = seq_info.max_seq_len |
286 | 292 | self.iter_counter = 0 |
287 | 293 | self.iter_states = {} |
288 | | - self.llm_args.max_seq_len = seq_info.max_seq_len |
289 | 294 |
|
290 | 295 | # NOTE (lucaslie): not a declared member of the base class; required by PyExecutor... |
291 | 296 | self.max_beam_width = max_beam_width |
@@ -487,6 +492,9 @@ def _compute_logits(self) -> List[torch.Tensor]: |
487 | 492 | # run the model |
488 | 493 | logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0] |
489 | 494 |
|
| 495 | + # TRTLLMSampler expects float32 logits; PyTorchModelEngine also always casts logits to float32. |
| 496 | + logits = logits.float() |
| 497 | + |
490 | 498 | # return a list of tensors |
491 | 499 | return self.cache_seq_interface.info.unnest_sequences(logits) |
492 | 500 |
|
@@ -574,6 +582,59 @@ def create_draft_model_engine_maybe( |
574 | 582 | return draft_model_engine |
575 | 583 |
|
576 | 584 |
|
| 585 | +class TRTLLMSamplerModelConfig: |
| 586 | + def __init__(self, vocab_size_padded: int): |
| 587 | + self.config = SimpleNamespace() |
| 588 | + self.config.vocab_size = vocab_size_padded |
| 589 | + |
| 590 | + # Initialized to dummy values as they are not used in the C++ code underlying TRTLLMSampler. |
| 591 | + self.config.num_hidden_layers = 42 |
| 592 | + self.config.hidden_size = 42 |
| 593 | + self.config.num_attention_heads = 42 |
| 594 | + |
| 595 | + |
| 596 | +def instantiate_sampler( |
| 597 | + ad_config: LlmArgs, |
| 598 | + max_num_sequences: int, |
| 599 | + max_draft_len: int, |
| 600 | + max_total_draft_tokens: int, |
| 601 | + dist_mapping: Mapping, |
| 602 | + engine: ADEngine, |
| 603 | +): |
| 604 | + if ad_config.sampler_type == SamplerType.TorchSampler: |
| 605 | + # search sampler with speculative decoding |
| 606 | + sampler_args = TorchSampler.Args( |
| 607 | + max_seq_len=ad_config.max_seq_len, |
| 608 | + max_draft_len=max_draft_len, |
| 609 | + max_total_draft_tokens=max_total_draft_tokens, |
| 610 | + max_num_sequences=max_num_sequences, |
| 611 | + max_beam_width=ad_config.max_beam_width, |
| 612 | + disable_overlap_scheduler=ad_config.disable_overlap_scheduler, |
| 613 | + ) |
| 614 | + sampler = TorchSampler(sampler_args) |
| 615 | + |
| 616 | + elif ad_config.sampler_type == SamplerType.TRTLLMSampler: |
| 617 | + vocab_size_padded: int = engine.cache_seq_interface.info.vocab_size_padded |
| 618 | + sampler_model_config = TRTLLMSamplerModelConfig(vocab_size_padded) |
| 619 | + decoding_mode = get_decoding_mode(ad_config.decoding_config, ad_config.max_beam_width) |
| 620 | + sampler = TRTLLMSampler( |
| 621 | + model=sampler_model_config, |
| 622 | + model_dtype=torch.bfloat16, # hardcoded to bfloat16; the dtype does not appear to be used by the underlying C++ code. |
| 623 | + mapping=dist_mapping, |
| 624 | + decoding_mode=decoding_mode, |
| 625 | + disable_overlap_scheduler=ad_config.disable_overlap_scheduler, |
| 626 | + max_seq_len=ad_config.max_seq_len, |
| 627 | + max_batch_size=ad_config.max_batch_size, |
| 628 | + max_beam_width=ad_config.max_beam_width, |
| 629 | + decoding_config=ad_config.decoding_config, |
| 630 | + kv_cache_config=ad_config.kv_cache_config, |
| 631 | + ) |
| 632 | + else: |
| 633 | + raise ValueError(f"Sampler type {ad_config.sampler_type} is not supported.") |
| 634 | + |
| 635 | + return sampler |
| 636 | + |
| 637 | + |
577 | 638 | def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[TokenizerBase] = None): |
578 | 639 | """Create an AutoDeploy executor from the given configuration and tokenizer. |
579 | 640 | The tokenizer is required for guided decoding. |
@@ -695,23 +756,21 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer |
695 | 756 | ) |
696 | 757 | scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler) |
697 | 758 |
|
698 | | - # search sampler with speculative decoding |
699 | | - sampler_args = TorchSampler.Args( |
700 | | - max_seq_len=ad_config.max_seq_len, |
| 759 | + vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded |
| 760 | + sampler = instantiate_sampler( |
| 761 | + ad_config=ad_config, |
| 762 | + max_num_sequences=max_num_sequences, |
701 | 763 | max_draft_len=max_draft_len, |
702 | 764 | max_total_draft_tokens=max_total_draft_tokens, |
703 | | - max_num_sequences=max_num_sequences, |
704 | | - max_beam_width=ad_config.max_beam_width, |
705 | | - disable_overlap_scheduler=ad_config.disable_overlap_scheduler, |
| 765 | + dist_mapping=dist_mapping, |
| 766 | + engine=engine, |
706 | 767 | ) |
707 | | - sampler = TorchSampler(sampler_args) |
708 | 768 |
|
709 | 769 | # Guided (structured) decoding. |
710 | 770 | guided_decoder = None |
711 | 771 | if ( |
712 | 772 | (guided_decoding_backend := ad_config.guided_decoding_backend) is not None |
713 | 773 | ) and dist_mapping.is_last_pp_rank(): |
714 | | - vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded |
715 | 774 | if vocab_size_padded is None: |
716 | 775 | raise RuntimeError( |
717 | 776 | "Could not determine the vocabulary size. Required for guided decoding." |
|
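For context, a minimal sketch of how the new sampler selection could be driven from the AutoDeploy config. Only sampler_type, SamplerType, and instantiate_sampler are taken from this diff; the model path, the numeric limits, and the dist_mapping/engine placeholders are illustrative assumptions, not part of the PR. LlmArgs here refers to the AutoDeploy LlmArgs already used in this file.

    from tensorrt_llm.llmapi.llm_args import SamplerType

    # Hypothetical AutoDeploy config; sampler_type=TRTLLMSampler routes through the
    # TRTLLMSampler branch of instantiate_sampler shown above.
    ad_config = LlmArgs(
        model="/path/to/model",  # placeholder model path
        sampler_type=SamplerType.TRTLLMSampler,
    )

    sampler = instantiate_sampler(
        ad_config=ad_config,
        max_num_sequences=8,        # illustrative limits, not values from the PR
        max_draft_len=0,
        max_total_draft_tokens=0,
        dist_mapping=dist_mapping,  # placeholder: the executor's Mapping
        engine=engine,              # placeholder: the constructed ADEngine
    )

Passing sampler_type=SamplerType.TorchSampler instead would take the existing TorchSampler branch; any other value raises the ValueError added at the end of instantiate_sampler.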