Skip to content

Commit f155812

Browse files
authored
[TRTLLM-6756][feat] Add Beam Search to TorchSampler (#8509)
Signed-off-by: Stefan Niebler <[email protected]>
1 parent b024040 commit f155812

File tree

10 files changed

+2156
-462
lines changed

10 files changed

+2156
-462
lines changed

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -436,6 +436,7 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
436436
max_total_draft_tokens=max_total_draft_tokens,
437437
max_num_sequences=max_num_sequences,
438438
max_beam_width=ad_config.max_beam_width,
439+
disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
439440
)
440441
sampler = TorchSampler(sampler_args)
441442

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -824,11 +824,16 @@ def create_py_executor_instance(
824824
virtual_memory_pools=virtual_memory_pools)
825825

826826

827-
def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
828-
max_batch_size: int,
829-
speculative_config: SpeculativeConfig,
830-
max_beam_width: int,
831-
disable_flashinfer_sampling: bool):
827+
def create_torch_sampler_args(
828+
mapping: Mapping,
829+
*,
830+
max_seq_len: int,
831+
max_batch_size: int,
832+
speculative_config: SpeculativeConfig,
833+
max_beam_width: int,
834+
disable_overlap_scheduler: bool,
835+
disable_flashinfer_sampling: bool,
836+
):
832837
max_num_sequences = max_batch_size * mapping.pp_size
833838
max_draft_len = (0 if speculative_config is None else
834839
speculative_config.max_draft_len)
@@ -842,7 +847,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
842847
max_num_sequences=max_num_sequences,
843848
max_beam_width=max_beam_width,
844849
disable_flashinfer_sampling=disable_flashinfer_sampling,
845-
)
850+
disable_overlap_scheduler=disable_overlap_scheduler)
846851

847852

848853
def instantiate_sampler(
@@ -865,6 +870,7 @@ def instantiate_sampler(
865870
max_batch_size=max_batch_size,
866871
speculative_config=speculative_config,
867872
max_beam_width=max_beam_width,
873+
disable_overlap_scheduler=llm_args.disable_overlap_scheduler,
868874
disable_flashinfer_sampling=disable_flashinfer_sampling,
869875
)
870876
decoding_mode = get_decoding_mode(decoding_config=decoding_config,

0 commit comments

Comments
 (0)