1+ # Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+ # Licensed under the Apache License, Version 2.0 (the "License");
3+ # you may not use this file except in compliance with the License.
4+ # You may obtain a copy of the License at
5+ # http://www.apache.org/licenses/LICENSE-2.0
6+ # Unless required by applicable law or agreed to in writing, software
7+ # distributed under the License is distributed on an "AS IS" BASIS,
8+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+ # See the License for the specific language governing permissions and
10+ # limitations under the License.
11+
112from collections import defaultdict
213from types import SimpleNamespace
314from typing import Dict , List , Optional , Tuple
617from strenum import StrEnum
718from torch ._prims_common import DeviceLikeType
819
20+ from tensorrt_llm ._torch .pyexecutor .guided_decoder import GuidedDecoder
21+ from tensorrt_llm ._torch .pyexecutor .py_executor_creator import get_guided_decoding_config
922from tensorrt_llm ._torch .pyexecutor .seq_slot_manager import SeqSlotManager
1023from tensorrt_llm ._utils import nvtx_range
1124from tensorrt_llm .llmapi .llm_args import ContextChunkingPolicy
25+ from tensorrt_llm .llmapi .tokenizer import TokenizerBase
1226
1327from ...._utils import mpi_rank , mpi_world_size
1428from ....bindings .internal .batch_manager import CacheType
2640)
2741from ..custom_ops .attention_interface import SequenceInfo
2842from ..distributed import common as dist
29- from ..llm_args import AutoDeployConfig , LlmArgs
43+ from ..llm_args import LlmArgs
3044from ..transform .optimizer import InferenceOptimizer
3145from ..utils .logger import ad_logger
3246from .interface import CachedSequenceInterface , GetInferenceModel
@@ -83,8 +97,8 @@ def _device(self) -> DeviceLikeType:
8397 return self .cache_seq_interface .device
8498
8599 @classmethod
86- def build_from_config (cls , ad_config : AutoDeployConfig ):
87- """Build the ADEngine using the AutoDeployConfig that gets passed through from the LLM."""
100+ def build_from_config (cls , ad_config : LlmArgs ):
101+ """Build the ADEngine using the LlmArgs that gets passed through from the LLM."""
88102
89103 max_batch_size = ad_config .max_batch_size
90104 max_seq_len = ad_config .max_seq_len
@@ -98,16 +112,17 @@ def build_from_config(cls, ad_config: AutoDeployConfig):
98112 device = torch .device (f"cuda:{ torch .cuda .current_device ()} " )
99113 device = str (device )
100114
115+ factory = ad_config .create_factory ()
116+
101117 # initialize seq info object
102118 seq_info = SequenceInfo (
103119 max_seq_len = max_seq_len ,
104120 max_batch_size = max_batch_size ,
105121 page_size = attn_page_size ,
106122 max_num_tokens = max_num_tokens ,
123+ vocab_size_padded = factory .vocab_size_padded ,
107124 )
108125
109- factory = ad_config .create_factory ()
110-
111126 # TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
112127 # ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.
113128
@@ -296,8 +311,9 @@ def forward(
296311 return {"logits" : logits_flat }
297312
298313
299- def create_autodeploy_executor (ad_config : LlmArgs ):
300- """Create an AutoDeploy executor from the given configuration and checkpoint directory.
314+ def create_autodeploy_executor (ad_config : LlmArgs , tokenizer : Optional [TokenizerBase ] = None ):
315+ """Create an AutoDeploy executor from the given configuration and tokenizer.
316+ The tokenizer is required for guided decoding.
301317
302318 This is the entrypoint API to the _autodeploy backend.
303319 """
@@ -404,6 +420,25 @@ def create_autodeploy_executor(ad_config: LlmArgs):
404420 )
405421 sampler = TorchSampler (sampler_args )
406422
423+ # Guided (structured) decoding.
424+ guided_decoder = None
425+ if (
426+ (guided_decoding_backend := ad_config .guided_decoding_backend ) is not None
427+ ) and dist_mapping .is_last_pp_rank ():
428+ vocab_size_padded = engine .cache_seq_interface .info .vocab_size_padded
429+ if vocab_size_padded is None :
430+ raise RuntimeError (
431+ "Could not determine the vocabulary size. Required for guided decoding."
432+ )
433+ guided_decoding_config = get_guided_decoding_config (
434+ guided_decoding_backend = guided_decoding_backend , tokenizer = tokenizer
435+ )
436+ guided_decoder = GuidedDecoder (
437+ guided_decoding_config = guided_decoding_config ,
438+ max_num_sequences = ad_config .max_batch_size ,
439+ vocab_size_padded = vocab_size_padded ,
440+ )
441+
407442 # creating the executor object
408443 py_executor = PyExecutor (
409444 resource_manager ,
@@ -418,5 +453,6 @@ def create_autodeploy_executor(ad_config: LlmArgs):
418453 max_draft_len = max_draft_len ,
419454 max_total_draft_tokens = max_total_draft_tokens ,
420455 max_beam_width = ad_config .max_beam_width ,
456+ guided_decoder = guided_decoder ,
421457 )
422458 return py_executor
0 commit comments