Commit 58a8a8f

feature: unify the sample-state new_tokens format with the TRTLLM sampler new_tokens format (NVIDIA#4401)

Signed-off-by: Netanel Haber <[email protected]>
1 parent: ebadc13

File tree: 12 files changed (+410, -504 lines)

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 14 additions & 4 deletions
@@ -4,6 +4,7 @@
 import torch
 from torch._prims_common import DeviceLikeType
 
+from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._utils import nvtx_range
 
 from ...._utils import mpi_rank, mpi_world_size
@@ -12,6 +13,7 @@
 from ....llmapi.llm_args import _AutoDeployLlmArgs
 from ....mapping import Mapping
 from ...distributed import MPIDist
+from ...pyexecutor._util import create_torch_sampler_args
 from ...pyexecutor.config import PyTorchConfig
 from ...pyexecutor.model_engine import ModelEngine
 from ...pyexecutor.py_executor import PyExecutor
@@ -292,7 +294,13 @@ def create_autodeploy_executor(
         max_seq_len=max_seq_len,
         max_batch_size=max_batch_size,
     )
-    resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
+    seq_slot_manager = SeqSlotManager(max_num_sequences=max_batch_size * dist_mapping.pp_size)
+    resource_manager = ResourceManager(
+        {
+            ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager,
+            ResourceManagerType.SEQ_SLOT_MANAGER: seq_slot_manager,
+        }
+    )
     resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)
 
     # scheduling
@@ -303,15 +311,17 @@ def create_autodeploy_executor(
     scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)
 
     # search sampler with speculative decoding
-    sampler = TorchSampler(max_seq_len=max_seq_len)
-
-    # creating the executor object
+    sampler_args = create_torch_sampler_args(
+        executor_config, dist_mapping, mixed_sampler=False, max_seq_len=max_seq_len
+    )
+    sampler = TorchSampler(sampler_args)
     py_executor = PyExecutor(
         resource_manager,
         scheduler,
         model_engine=engine,
         sampler=sampler,
         dist=mpi_dist,
+        max_num_sequences=ad_config.max_batch_size * dist_mapping.pp_size,
         disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
         max_input_len=ad_config.max_input_len,
         max_batch_size=ad_config.max_batch_size,
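
The new SEQ_SLOT_MANAGER entry is registered before the existing move_to_end call, which only works if the resource-manager registry preserves insertion order. A minimal sketch of that pattern, assuming an OrderedDict-backed registry (the enum values mirror the names used above; the stand-in objects are illustrative, not the real managers):

# Hedged sketch, not code from this commit: it only illustrates why the diff
# still calls move_to_end after registering the new SEQ_SLOT_MANAGER entry.
from collections import OrderedDict
from enum import Enum, auto


class ResourceManagerType(Enum):
    KV_CACHE_MANAGER = auto()
    SEQ_SLOT_MANAGER = auto()


# Registration order becomes iteration (preparation) order.
resource_managers = OrderedDict({
    ResourceManagerType.KV_CACHE_MANAGER: object(),  # stand-in for KVCacheManager
    ResourceManagerType.SEQ_SLOT_MANAGER: object(),  # stand-in for SeqSlotManager
})

# Keep the KV cache manager last, as the diff does after adding the slot manager.
resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)
assert list(resource_managers)[-1] is ResourceManagerType.KV_CACHE_MANAGER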

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 32 additions & 19 deletions
@@ -26,8 +26,7 @@
 from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
                                PeftCacheManager, ResourceManager,
                                ResourceManagerType)
-from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
-                      TRTLLMSampler)
+from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler
 from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
                         SimpleScheduler)
 from .seq_slot_manager import SeqSlotManager
@@ -506,6 +505,7 @@ def create_py_executor_instance(
         model_engine=model_engine,
         sampler=sampler,
         dist=dist,
+        max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=pytorch_backend_config.
         disable_overlap_scheduler,
         max_batch_size=executor_config.max_batch_size,
@@ -517,31 +517,44 @@
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
 
 
-def instantiate_sampler(model_engine: PyTorchModelEngine,
+def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
+                              *, max_seq_len: int, mixed_sampler: bool):
+    max_num_sequences = executor_config.max_batch_size * mapping.pp_size
+    max_draft_tokens = (0 if executor_config.speculative_config is None else
+                        executor_config.speculative_config.max_draft_tokens)
+    return TorchSampler.Args(
+        max_seq_len=max_seq_len,
+        max_draft_tokens=max_draft_tokens,
+        max_num_sequences=max_num_sequences,
+        max_beam_width=executor_config.max_beam_width,
+        mixed_sampler=mixed_sampler,
+    )
+
+
+def instantiate_sampler(engine: PyTorchModelEngine,
                         executor_config: ExecutorConfig,
                         pytorch_backend_config: PyTorchConfig,
                         mapping: Mapping):
+    sampler_args = create_torch_sampler_args(
+        executor_config,
+        mapping,
+        max_seq_len=engine.max_seq_len,
+        mixed_sampler=pytorch_backend_config.mixed_sampler)
     if mapping.cp_config.get('cp_type') == 'star_attention':
         assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
-        sampler = TorchStarAttentionSampler(
-            max_seq_len=model_engine.max_seq_len)
-    elif model_engine.spec_config is not None and model_engine.spec_config.spec_dec_mode.has_spec_decoder(
+        return TorchSampler(sampler_args)
+    if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
     ):
-        sampler = get_spec_decoder(max_seq_len=model_engine.max_seq_len,
-                                   spec_config=model_engine.spec_config)
-    elif pytorch_backend_config.enable_trtllm_sampler:
+        return get_spec_decoder(sampler_args, engine.spec_config)
+    if pytorch_backend_config.enable_trtllm_sampler:
         decoding_mode = get_decoding_mode(executor_config)
-        sampler = TRTLLMSampler(
-            executor_config, model_engine.model, model_engine.dtype, mapping,
-            decoding_mode, pytorch_backend_config.disable_overlap_scheduler)
-    elif not model_engine.model.model_config.is_generation:
+        return TRTLLMSampler(executor_config, engine.model, engine.dtype,
+                             mapping, decoding_mode,
+                             pytorch_backend_config.disable_overlap_scheduler)
+    if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
-        sampler = EarlyStopSampler()
-    else:
-        sampler = TorchSampler(
-            max_seq_len=model_engine.max_seq_len,
-            mixed_sampler=pytorch_backend_config.mixed_sampler)
-    return sampler
+        return EarlyStopSampler()
+    return TorchSampler(sampler_args)
 
 
 def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
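
create_torch_sampler_args bundles every sizing knob into a single TorchSampler.Args object, so the TorchSampler path, the spec decoders, and the AutoDeploy executor above can all share one construction signature. A hedged sketch of the shape that call site implies (field names are taken from the keyword arguments above; the real class is defined in tensorrt_llm/_torch/pyexecutor/sampler.py and may carry extra fields or defaults):

# Hedged sketch only: a plausible TorchSampler.Args shape inferred from the
# keyword arguments passed in create_torch_sampler_args; not copied from the
# real sampler.py.
from dataclasses import dataclass


@dataclass(frozen=True)
class Args:
    max_seq_len: int
    max_draft_tokens: int
    max_num_sequences: int
    max_beam_width: int
    mixed_sampler: bool


# Mirrors the call site: all sizing parameters travel in one object instead of
# loose keyword arguments on each sampler constructor.
args = Args(max_seq_len=4096,
            max_draft_tokens=0,
            max_num_sequences=8,
            max_beam_width=1,
            mixed_sampler=False)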

tensorrt_llm/_torch/pyexecutor/guided_decoder.py

Lines changed: 2 additions & 6 deletions
@@ -1,4 +1,3 @@
-import itertools
 import math
 from typing import List, Optional
 
@@ -52,8 +51,7 @@ def bitmask_size(self) -> int:
 
     def build(self, scheduled_requests: ScheduledRequests,
               resource_manager: SeqSlotManager) -> None:
-        for llm_req in itertools.chain(scheduled_requests.context_requests,
-                                       scheduled_requests.generation_requests):
+        for llm_req in scheduled_requests.all_requests():
             if llm_req.guided_decoding_params is None:
                 continue
             slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
@@ -84,9 +82,7 @@ def execute(self, scheduled_requests: ScheduledRequests,
         torch.cuda.current_stream().wait_stream(self._stream)
 
         batched_logits, batched_bitmask = [], []
-        for i, llm_req in enumerate(
-                itertools.chain(scheduled_requests.context_requests,
-                                scheduled_requests.generation_requests)):
+        for i, llm_req in enumerate(scheduled_requests.all_requests()):
             if llm_req.guided_decoding_params is None:
                 continue
             if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
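
Both loops now go through scheduled_requests.all_requests(), whose definition is not part of this file's diff. Judging by the itertools.chain calls it replaces, it presumably behaves roughly like the hedged sketch below (the real method lives on ScheduledRequests elsewhere in the codebase and may return an iterator rather than a list):

# Hedged sketch of the helper the two loops above now rely on; inferred from
# the removed itertools.chain(...) expressions, not copied from the real class.
class ScheduledRequests:

    def __init__(self, context_requests, generation_requests):
        self.context_requests = context_requests
        self.generation_requests = generation_requests

    def all_requests(self):
        # Context requests first, then generation requests, matching the
        # iteration order of the code this method replaces.
        return [*self.context_requests, *self.generation_requests]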

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 2 additions & 0 deletions
@@ -253,6 +253,7 @@ def __init__(
             return_logits_device_memory: bool = True,
             exclude_last_generation_logits: bool = False,
             stop_words_list: list[list[int]] | None = None,
+            is_draft: bool = False,
             **kwargs):
         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
                                                     None)
@@ -288,6 +289,7 @@ def __init__(
         self.py_return_context_logits = return_context_logits
         self.py_return_generation_logits = return_generation_logits
         self.py_return_logits_device_memory = return_logits_device_memory
+        self.py_is_draft = is_draft
 
         # TODO: remove this when use DynamicDecodeOp in pytorch flow.
         # currently, keep py_stop_words_list as python list, rather than tensor.
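
The constructor gains an is_draft flag, defaulting to False, that is simply mirrored onto a py_is_draft attribute for the PyTorch flow. A self-contained toy sketch of that pattern (this is not the real LlmRequest, which takes many more arguments):

# Illustrative only: a stand-in class showing the is_draft -> py_is_draft
# pattern added above; the real LlmRequest has a much larger signature.
class _DraftAwareRequest:

    def __init__(self, is_draft: bool = False, **kwargs):
        self.py_is_draft = is_draft


draft_req = _DraftAwareRequest(is_draft=True)
assert draft_req.py_is_draft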
