Skip to content

Commit c3a3a01

Browse files
govind-ramnarayan, 2ez4bz, and lucaslie
authored and committed
[NVIDIA#8245][feat] Autodeploy: Guided Decoding Support (NVIDIA#8551)
Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Signed-off-by: Govind Ramnarayan <105831528+govind-ramnarayan@users.noreply.github.com>
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Co-authored-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
Co-authored-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent ae4a862 commit c3a3a01

File tree

10 files changed

+412
-16
lines changed

10 files changed

+412
-16
lines changed

examples/auto_deploy/build_and_run_ad.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616

1717
from tensorrt_llm._torch.auto_deploy import LLM, AutoDeployConfig, DemoLLM
18+
from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs
1819
from tensorrt_llm._torch.auto_deploy.utils._config import (
1920
DynamicYamlMixInForSettings,
2021
deep_merge_dicts,
@@ -139,9 +140,10 @@ class ExperimentConfig(DynamicYamlMixInForSettings, BaseSettings):
139140

140141
### CORE ARGS ##################################################################################
141142
# The main AutoDeploy arguments - contains model, tokenizer, backend configs, etc.
142-
args: AutoDeployConfig = Field(
143+
args: LlmArgs = Field(
143144
description="The main AutoDeploy arguments containing model, tokenizer, backend configs, etc. "
144-
"Please check `tensorrt_llm._torch.auto_deploy.llm_args.AutoDeployConfig` for more details."
145+
"Contains all the fields from `AutoDeployConfig` and `BaseLlmArgs`. "
146+
"Please check `tensorrt_llm._torch.auto_deploy.llm_args.LlmArgs` for more details."
145147
)
146148

147149
# Optional model field for convenience - if provided, will be used to initialize args.model
@@ -304,6 +306,7 @@ def main(config: Optional[ExperimentConfig] = None):
304306
store_benchmark_results(results, config.benchmark.results_path)
305307

306308
llm.shutdown()
309+
return results
307310

308311

309312
if __name__ == "__main__":

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def __init__(
8787
max_batch_size: int = 1,
8888
page_size: int = 0,
8989
max_num_tokens: Optional[int] = None,
90+
vocab_size_padded: Optional[int] = None,
9091
):
9192
"""Initialize the SequenceInfo object.
9293
@@ -104,14 +105,15 @@ def __init__(
104105
batch is min (max_batch_size, max_num_tokens // ISL). Similarly, if a batch is
105106
composed of generate-only requests, then the maximum number of sequences possible in
106107
the batch is min (max_batch_size, max_num_tokens).
107-
108+
vocab_size_padded: corresponds to the padded vocabulary size of the model.
108109
Returns:
109110
None
110111
"""
111112
# set up basic attributes
112113
self.max_seq_len = max_seq_len
113114
self.max_batch_size = max_batch_size
114115
self.page_size = page_size if page_size > 0 else max_seq_len
116+
self.vocab_size_padded = vocab_size_padded
115117

116118
# NOTE (lucaslie): WAR to address issue when using flashinfer attention with
117119
# (max_batch_size, max_seq_len) input in trtllm runtime.

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,18 @@ def tokenizer(self) -> Optional[str]:
131131
"""The tokenizer path."""
132132
return self._prefetched_tokenizer_path or self._tokenizer or self.model
133133

134+
@property
135+
def vocab_size_padded(self) -> Optional[int]:
136+
"""Return the padded vocabulary size of the model.
137+
138+
This is needed for guided decoding in the pyexecutor. If the factory does not support this,
139+
then this method should return None.
140+
141+
Returns:
142+
The padded vocabulary size of the model.
143+
"""
144+
return None
145+
134146
def build_model(self, device: str) -> nn.Module:
135147
"""Build the model on the desired device.
136148
@@ -164,10 +176,7 @@ def forward(
164176
the factory.
165177
"""
166178
# make sure model architecture is pre-fetched (no weights needed at this point)
167-
skip_loading_weights = self.skip_loading_weights
168-
self.skip_loading_weights = True
169-
self.prefetch_checkpoint()
170-
self.skip_loading_weights = skip_loading_weights
179+
self.prefetch_checkpoint(skip_loading_weights=True)
171180

172181
# build the model
173182
return self._build_model(device)
@@ -211,15 +220,18 @@ def init_processor(self) -> Optional[Any]:
211220
"""
212221
return None
213222

214-
def prefetch_checkpoint(self, force: bool = False):
223+
def prefetch_checkpoint(self, force: bool = False, skip_loading_weights: Optional[bool] = None):
215224
"""Try or skip prefetching the checkpoint for the model and tokenizer.
216225
217226
Args:
218227
force: Whether to force prefetching the checkpoint.
228+
skip_loading_weights: Whether to skip loading weights. If not provided, it will use
229+
the factory's skip_loading_weights value.
219230
"""
220231
if not self._prefetched_model_path or force:
221232
self._prefetched_model_path = self._prefetch_checkpoint(
222-
self._model, self.skip_loading_weights
233+
self._model,
234+
self.skip_loading_weights if skip_loading_weights is None else skip_loading_weights,
223235
)
224236
if self._tokenizer and (not self._prefetched_tokenizer_path or force):
225237
self._prefetched_tokenizer_path = self._prefetch_checkpoint(self._tokenizer, True)

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ def __init__(self, *args, **kwargs):
119119
def automodel_cls(self) -> Type[_BaseAutoModelClass]:
120120
return AutoModelForCausalLM
121121

122+
@property
123+
def vocab_size_padded(self) -> Optional[int]:
124+
model_config, _ = self._get_model_config()
125+
return getattr(model_config, "vocab_size", None)
126+
122127
def _recursive_update_config(
123128
self, config: PretrainedConfig, update_dict: Dict[str, Any]
124129
) -> Tuple[PretrainedConfig, Dict[str, Any]]:
@@ -167,6 +172,9 @@ def _recursive_update_config(
167172
return config, nested_unused_kwargs
168173

169174
def _get_model_config(self) -> Tuple[PretrainedConfig, Dict[str, Any]]:
175+
# prefetch the model once without weights
176+
self.prefetch_checkpoint(skip_loading_weights=True)
177+
170178
# NOTE (lucaslie): HF doesn't recursively update nested PreTrainedConfig objects. Instead,
171179
# the entire subconfig will be overwritten.
172180
# we want to recursively update model_config from model_kwargs here.

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
112
from collections import defaultdict
213
from types import SimpleNamespace
314
from typing import Dict, List, Optional, Tuple
@@ -6,9 +17,12 @@
617
from strenum import StrEnum
718
from torch._prims_common import DeviceLikeType
819

20+
from tensorrt_llm._torch.pyexecutor.guided_decoder import GuidedDecoder
21+
from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
922
from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
1023
from tensorrt_llm._utils import nvtx_range
1124
from tensorrt_llm.llmapi.llm_args import ContextChunkingPolicy
25+
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
1226

1327
from ...._utils import mpi_rank, mpi_world_size
1428
from ....bindings.internal.batch_manager import CacheType
@@ -26,7 +40,7 @@
2640
)
2741
from ..custom_ops.attention_interface import SequenceInfo
2842
from ..distributed import common as dist
29-
from ..llm_args import AutoDeployConfig, LlmArgs
43+
from ..llm_args import LlmArgs
3044
from ..transform.optimizer import InferenceOptimizer
3145
from ..utils.logger import ad_logger
3246
from .interface import CachedSequenceInterface, GetInferenceModel
@@ -83,8 +97,8 @@ def _device(self) -> DeviceLikeType:
8397
return self.cache_seq_interface.device
8498

8599
@classmethod
86-
def build_from_config(cls, ad_config: AutoDeployConfig):
87-
"""Build the ADEngine using the AutoDeployConfig that gets passed through from the LLM."""
100+
def build_from_config(cls, ad_config: LlmArgs):
101+
"""Build the ADEngine using the LlmArgs that gets passed through from the LLM."""
88102

89103
max_batch_size = ad_config.max_batch_size
90104
max_seq_len = ad_config.max_seq_len
@@ -98,16 +112,17 @@ def build_from_config(cls, ad_config: AutoDeployConfig):
98112
device = torch.device(f"cuda:{torch.cuda.current_device()}")
99113
device = str(device)
100114

115+
factory = ad_config.create_factory()
116+
101117
# initialize seq info object
102118
seq_info = SequenceInfo(
103119
max_seq_len=max_seq_len,
104120
max_batch_size=max_batch_size,
105121
page_size=attn_page_size,
106122
max_num_tokens=max_num_tokens,
123+
vocab_size_padded=factory.vocab_size_padded,
107124
)
108125

109-
factory = ad_config.create_factory()
110-
111126
# TODO (lucaslie): consider how we move args around InferenceOptimizer.__init__,
112127
# ADEngine.__init__, and ADEngine.build_from_config. Seems a bit unnatural atm.
113128

@@ -296,8 +311,9 @@ def forward(
296311
return {"logits": logits_flat}
297312

298313

299-
def create_autodeploy_executor(ad_config: LlmArgs):
300-
"""Create an AutoDeploy executor from the given configuration and checkpoint directory.
314+
def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[TokenizerBase] = None):
315+
"""Create an AutoDeploy executor from the given configuration and tokenizer.
316+
The tokenizer is required for guided decoding.
301317
302318
This is the entrypoint API to the _autodeploy backend.
303319
"""
@@ -404,6 +420,25 @@ def create_autodeploy_executor(ad_config: LlmArgs):
404420
)
405421
sampler = TorchSampler(sampler_args)
406422

423+
# Guided (i.e., structured) decoding.
424+
guided_decoder = None
425+
if (
426+
(guided_decoding_backend := ad_config.guided_decoding_backend) is not None
427+
) and dist_mapping.is_last_pp_rank():
428+
vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded
429+
if vocab_size_padded is None:
430+
raise RuntimeError(
431+
"Could not determine the vocabulary size. Required for guided decoding."
432+
)
433+
guided_decoding_config = get_guided_decoding_config(
434+
guided_decoding_backend=guided_decoding_backend, tokenizer=tokenizer
435+
)
436+
guided_decoder = GuidedDecoder(
437+
guided_decoding_config=guided_decoding_config,
438+
max_num_sequences=ad_config.max_batch_size,
439+
vocab_size_padded=vocab_size_padded,
440+
)
441+
407442
# creating the executor object
408443
py_executor = PyExecutor(
409444
resource_manager,
@@ -418,5 +453,6 @@ def create_autodeploy_executor(ad_config: LlmArgs):
418453
max_draft_len=max_draft_len,
419454
max_total_draft_tokens=max_total_draft_tokens,
420455
max_beam_width=ad_config.max_beam_width,
456+
guided_decoder=guided_decoder,
421457
)
422458
return py_executor

tensorrt_llm/executor/base_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def _create_py_executor():
130130
create_executor = create_autodeploy_executor
131131
assert isinstance(self.llm_args, ADLlmArgs)
132132
args["ad_config"] = self.llm_args.get_pytorch_backend_config()
133+
args["tokenizer"] = self._tokenizer
133134
else:
134135
raise ValueError(
135136
f"Unsupported backend config: {self.llm_args.backend}")
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import json
17+
import os
18+
19+
from build_and_run_ad import ExperimentConfig, main
20+
from defs.conftest import llm_models_root
21+
22+
from tensorrt_llm.sampling_params import GuidedDecodingParams
23+
24+
25+
def test_autodeploy_guided_decoding_main_json():
26+
schema = (
27+
"{"
28+
'"title": "WirelessAccessPoint", "type": "object", "properties": {'
29+
'"ssid": {"title": "SSID", "type": "string"}, '
30+
'"securityProtocol": {"title": "SecurityProtocol", "type": "string"}, '
31+
'"bandwidth": {"title": "Bandwidth", "type": "string"}}, '
32+
'"required": ["ssid", "securityProtocol", "bandwidth"]}')
33+
34+
model_path = os.path.join(llm_models_root(),
35+
"llama-models-v2/TinyLlama-1.1B-Chat-v1.0")
36+
37+
print(f"model_path: {model_path}")
38+
llm_args = {
39+
"model": model_path,
40+
"guided_decoding_backend": "xgrammar",
41+
"skip_loading_weights": False,
42+
}
43+
44+
experiment_config = {
45+
"args": llm_args,
46+
"benchmark": {
47+
"enabled": False
48+
},
49+
"prompt": {
50+
"batch_size":
51+
1,
52+
"queries":
53+
("Please provide a JSON object representing a wireless access point. "
54+
"Follow this exact schema: " + schema),
55+
},
56+
}
57+
58+
# DemoLLM runtime does not support guided decoding. Need to set runtime to trtllm.
59+
experiment_config["args"]["runtime"] = "trtllm"
60+
experiment_config["args"]["world_size"] = 1
61+
62+
cfg = ExperimentConfig(**experiment_config)
63+
64+
# Need to introduce the guided decoding params after ExperimentConfig construction
65+
# because otherwise they get unpacked as a dict.
66+
cfg.prompt.sp_kwargs = {
67+
"max_tokens": 100,
68+
"top_k": None,
69+
"temperature": 0.1,
70+
"guided_decoding": GuidedDecodingParams(json=schema),
71+
}
72+
73+
result = main(cfg)
74+
print(f"guided_text: {result}")
75+
76+
# Extract the generated text from the nested structure
77+
# Format: {'prompts_and_outputs': [[prompt, output]]}
78+
assert "prompts_and_outputs" in result, "Result should contain 'prompts_and_outputs'"
79+
assert len(result["prompts_and_outputs"]
80+
) > 0, "Should have at least one prompt/output pair"
81+
82+
_prompt, generated_text = result["prompts_and_outputs"][0]
83+
print(f"Generated text: {generated_text}")
84+
85+
# Parse and validate the JSON
86+
try:
87+
guided_json = json.loads(generated_text)
88+
except Exception as e:
89+
print(
90+
f"Failed to parse generated text as JSON. Raw text: {generated_text!r}"
91+
)
92+
raise AssertionError(f"Generated text is not valid JSON: {e}") from e
93+
94+
# Assert the JSON conforms to the schema
95+
assert "ssid" in guided_json, "JSON must contain 'ssid' field"
96+
assert "securityProtocol" in guided_json, "JSON must contain 'securityProtocol' field"
97+
assert "bandwidth" in guided_json, "JSON must contain 'bandwidth' field"
98+
99+
# Validate field types
100+
assert isinstance(guided_json["ssid"], str), "'ssid' must be a string"
101+
assert isinstance(guided_json["securityProtocol"],
102+
str), "'securityProtocol' must be a string"
103+
assert isinstance(guided_json["bandwidth"],
104+
str), "'bandwidth' must be a string"
105+
106+
# Validate non-empty values
107+
assert len(guided_json["ssid"]) > 0, "'ssid' must not be empty"
108+
assert len(guided_json["securityProtocol"]
109+
) > 0, "'securityProtocol' must not be empty"
110+
assert len(guided_json["bandwidth"]) > 0, "'bandwidth' must not be empty"
111+
112+
print(f"Validation passed! Generated JSON: {guided_json}")

tests/integration/defs/pytest.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ threadleak = True
44
threadleak_exclude = asyncio_\d+
55
junit_family=legacy
66
addopts = --ignore-glob="*perf/test_perf.py" --ignore-glob="*test_list_validation.py" --ignore-glob="*llm-test-workspace*" --durations=0 -W ignore::DeprecationWarning
7+
pythonpath =
8+
../../../examples/auto_deploy
79
norecursedirs = ./triton/perf
810
markers =
911
skip_less_device: skip when less device detected than the declared

0 commit comments

Comments (0)