Commit c90ebad

Ubospica and syuoni authored
feat: Support the Structural Tag in guided decoding (NVIDIA#4066)
* finish (Signed-off-by: Ubospica <[email protected]>)
* update (Signed-off-by: Ubospica <[email protected]>)
* update (Signed-off-by: Ubospica <[email protected]>)
* fix (Signed-off-by: Enwei Zhu <[email protected]>)
* exc overlap scheduler (Signed-off-by: Enwei Zhu <[email protected]>)
* add test (Signed-off-by: Enwei Zhu <[email protected]>)
* fix api ref (Signed-off-by: Enwei Zhu <[email protected]>)

Signed-off-by: Ubospica <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
Co-authored-by: Enwei Zhu <[email protected]>
1 parent 3e9bda3 commit c90ebad

File tree

10 files changed: +268, -23 lines


cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 3 additions & 0 deletions
@@ -483,6 +483,9 @@ class GuidedDecodingParams
     /// @brief The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar.
     /// EBNF grammar is widely-used to express context-free grammars.
     kEBNF_GRAMMAR = 3,
+
+    /// @brief The generated text is amenable to the XGrammar structural tag.
+    kSTRUCTURAL_TAG = 4,
 };

 explicit GuidedDecodingParams(GuideType guideType, std::optional<std::string> guide = std::nullopt);

cpp/tensorrt_llm/pybind/executor/request.cpp

Lines changed: 2 additions & 1 deletion
@@ -462,7 +462,8 @@ void initRequestBindings(pybind11::module_& m)
         .value("JSON", tle::GuidedDecodingParams::GuideType::kJSON)
         .value("JSON_SCHEMA", tle::GuidedDecodingParams::GuideType::kJSON_SCHEMA)
         .value("REGEX", tle::GuidedDecodingParams::GuideType::kREGEX)
-        .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR);
+        .value("EBNF_GRAMMAR", tle::GuidedDecodingParams::GuideType::kEBNF_GRAMMAR)
+        .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG);

     auto guidedDecodingParamsGetstate
         = [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); };
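For orientation, here is a minimal sketch of exercising the new enum value through the Python bindings. The tensorrt_llm.bindings.executor import path and the placeholder guide payload are assumptions for illustration, not part of this diff:

import json

import tensorrt_llm.bindings.executor as tle

# Placeholder structural-tag payload; a real guide carries the
# structures and triggers that guided_decoder.py parses below.
guide = json.dumps({"structures": [], "triggers": []})
params = tle.GuidedDecodingParams(
    tle.GuidedDecodingParams.GuideType.STRUCTURAL_TAG, guide)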

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 3 additions & 0 deletions
@@ -339,6 +339,9 @@ def create_py_executor_instance(dist,
         if spec_config is not None:
             raise ValueError(
                 "Guided decoding is not supported with speculative decoding.")
+        if pytorch_backend_config.enable_overlap_scheduler:
+            raise ValueError(
+                "Guided decoding is not supported with overlap scheduler.")

     logger.info(
         f"max_seq_len={executor_config.max_seq_len}, max_num_requests={executor_config.max_batch_size}, max_num_tokens={executor_config.max_num_tokens}"

tensorrt_llm/_torch/pyexecutor/guided_decoder.py

Lines changed: 13 additions & 0 deletions
@@ -1,4 +1,5 @@
 import itertools
+import json
 import math
 from typing import List, Optional

@@ -82,6 +83,18 @@ def build(self, scheduled_requests: ScheduledRequests,
                     grammar = xgrammar.Grammar.from_ebnf(guide)
                     compiled_grammar = self.xgrammar_compiler.compile_grammar(
                         grammar)
+                case GuidedDecodingParams.GuideType.STRUCTURAL_TAG:
+                    structural_tag_parameters = json.loads(guide)
+                    structures = structural_tag_parameters["structures"]
+                    structures = [
+                        xgrammar.StructuralTagItem(
+                            begin=s["begin"],
+                            schema=json.dumps(s["schema"]),
+                            end=s["end"]) for s in structures
+                    ]
+                    triggers = structural_tag_parameters["triggers"]
+                    compiled_grammar = self.xgrammar_compiler.compile_structural_tag(
+                        structures, triggers)
                 case _:
                     raise ValueError(
                         f"Unrecognized guide type: {guide_type}.")

tensorrt_llm/sampling_params.py

Lines changed: 7 additions & 1 deletion
@@ -20,11 +20,13 @@ class GuidedDecodingParams:
         regex (str, optional): The generated text is amenable to the user-specified regular expression. Defaults to None.
         grammar (str, optional): The generated text is amenable to the user-specified extended Backus-Naur form (EBNF) grammar. Defaults to None.
         json_object (bool): If True, the generated text is amenable to json format. Defaults to False.
+        structural_tag (str, optional): The generated text is amenable to the user-specified structural tag. Defaults to None.
     """
     json: Optional[Union[str, BaseModel, dict]] = None
     regex: Optional[str] = None
     grammar: Optional[str] = None
     json_object: bool = False
+    structural_tag: Optional[str] = None

     def _validate(self):
         num_guides = 0

@@ -451,7 +453,7 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
                 tllme.GuidedDecodingParams.GuideType.JSON)
         elif self.guided_decoding.json is not None:
             json_schema = self.guided_decoding.json
-            if isinstance(json, BaseModel):
+            if isinstance(json_schema, BaseModel):
                 json_schema = json_schema.model_json_schema()
             if isinstance(json_schema, dict):
                 json_schema = json.dumps(json_schema)

@@ -465,5 +467,9 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams:
             return tllme.GuidedDecodingParams(
                 tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR,
                 self.guided_decoding.grammar)
+        elif self.guided_decoding.structural_tag is not None:
+            return tllme.GuidedDecodingParams(
+                tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG,
+                self.guided_decoding.structural_tag)
         else:
             return None
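Combined with the LLM API, usage looks roughly as follows; the model name is a placeholder and the guide string follows the JSON shape sketched above:

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import GuidedDecodingParams

# Hypothetical structural-tag guide (same shape as the payload above).
guide = ('{"structures": [{"begin": "<function=get_weather>", '
         '"schema": {"type": "object", "properties": '
         '{"city": {"type": "string"}}}, "end": "</function>"}], '
         '"triggers": ["<function="]}')

llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder model
params = SamplingParams(
    guided_decoding=GuidedDecodingParams(structural_tag=guide))
output = llm.generate("What is the weather in Paris?", params)
print(output.outputs[0].text)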

tensorrt_llm/serve/openai_protocol.py

Lines changed: 36 additions & 21 deletions
@@ -13,7 +13,7 @@
 from typing_extensions import Annotated, Required, TypedDict

 from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
-from tensorrt_llm.llmapi import SamplingParams
+from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams


 class OpenAIBaseModel(BaseModel):

@@ -44,9 +44,17 @@ class ModelList(OpenAIBaseModel):
     data: List[ModelCard] = Field(default_factory=list)


+class StructuralTag(OpenAIBaseModel):
+    begin: str
+    schema_: Optional[dict[str, Any]] = Field(alias="schema")
+    end: str
+
+
 class ResponseFormat(OpenAIBaseModel):
-    # type must be "json_object" or "text"
-    type: Literal["text", "json_object"]
+    # type must be "json_object" or "text" or "structural_tag"
+    type: Literal["text", "json_object", "structural_tag"]
+    structures: Optional[List[StructuralTag]] = None
+    triggers: Optional[List[str]] = None


 class DisaggregatedParams(OpenAIBaseModel):

@@ -121,6 +129,23 @@ class CompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)


+def _response_format_to_guided_decoding_params(
+        response_format: Optional[ResponseFormat]
+) -> Optional[GuidedDecodingParams]:
+    if response_format is None:
+        return None
+    elif response_format.type == "text":
+        return None
+    elif response_format.type == "json_object":
+        return GuidedDecodingParams(json_object=True)
+    elif response_format.type == "structural_tag":
+        return GuidedDecodingParams(
+            structural_tag=response_format.model_dump_json(by_alias=True,
+                                                           exclude_none=True))
+    else:
+        raise ValueError(f"Unsupported response format: {response_format.type}")
+
+
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create

@@ -170,10 +195,10 @@ class CompletionRequest(OpenAIBaseModel):
     )
     response_format: Optional[ResponseFormat] = Field(
         default=None,
-        description=(
-            "Similar to chat completion, this parameter specifies the format of "
-            "output. Only {'type': 'json_object'} or {'type': 'text' } is "
-            "supported."),
+        description=
+        ("Similar to chat completion, this parameter specifies the format of "
+         "output. {'type': 'json_object'}, {'type': 'text' }, {'type': 'structural_tag'} are "
+         "supported."),
     )

     disaggregated_params: Optional[DisaggregatedParams] = Field(

@@ -211,6 +236,8 @@ def to_sampling_params(self) -> SamplingParams:
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
             return_context_logits=self.return_context_logits,
+            guided_decoding=_response_format_to_guided_decoding_params(
+                self.response_format),

             # completion-extra-params
             add_special_tokens=self.add_special_tokens,

@@ -255,13 +282,6 @@ def verify_multi_responses(cls, data):
             raise ValueError("best_of should not be smaller than n")
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_response_format(cls, data):
-        if data.get("response_format"):
-            raise ValueError("response_format is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):

@@ -520,6 +540,8 @@ def to_sampling_params(self) -> SamplingParams:
             skip_special_tokens=self.skip_special_tokens,
             spaces_between_special_tokens=self.spaces_between_special_tokens,
             truncate_prompt_tokens=self.truncate_prompt_tokens,
+            guided_decoding=_response_format_to_guided_decoding_params(
+                self.response_format),

             # chat-completion-extra-params
             add_special_tokens=self.add_special_tokens,

@@ -582,13 +604,6 @@ def verify_logit_processor(cls, data):
             raise ValueError("logit bias is not supported")
         return data

-    @model_validator(mode="before")
-    @classmethod
-    def check_response_format(cls, data):
-        if data.get("response_format"):
-            raise ValueError("response_format is not supported")
-        return data
-
     @model_validator(mode="before")
     @classmethod
     def check_suffix(cls, data):
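With the check_response_format validators removed, a structural-tag response_format now flows through to_sampling_params on both the completion and chat endpoints. A client-side sketch of the new request shape; the server URL, model name, and tag contents are assumptions, and the OpenAI client is assumed to forward the extra "structures" and "triggers" fields unchanged:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
response = client.chat.completions.create(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder model
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    response_format={
        "type": "structural_tag",
        "structures": [{
            "begin": "<function=get_weather>",
            "schema": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
            },
            "end": "</function>",
        }],
        "triggers": ["<function="],
    },
)
print(response.choices[0].message.content)

On the server side, _response_format_to_guided_decoding_params re-serializes this object with model_dump_json(by_alias=True, exclude_none=True), yielding exactly the guide string that guided_decoder.py parses.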

tests/integration/defs/test_e2e.py

Lines changed: 9 additions & 0 deletions
@@ -1121,6 +1121,15 @@ def test_openai_chat_multimodal_example(llm_root, llm_venv):
         str(test_root / "_test_openai_chat_multimodal.py")])


+def test_openai_chat_structural_tag_example(llm_venv):
+    test_root = unittest_path() / "llmapi" / "apps"
+
+    llm_venv.run_cmd([
+        "-m", "pytest",
+        str(test_root / "_test_openai_chat_structural_tag.py")
+    ])
+
+
 @pytest.mark.skip_less_device(2)
 @pytest.mark.skip_less_device_memory(40000)
 def test_openai_multi_chat_example(llm_root, llm_venv):

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ l0_a10:
   - disaggregated/test_disaggregated.py::test_disaggregated_overlap[TinyLlama-1.1B-Chat-v1.0]
   - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-MAX_UTILIZATION-pytorch-stress-test]
   - stress_test/stress_test.py::test_run_stress_test[llama-v3-8b-instruct-hf_tp1-stress_time_300s_timeout_450s-GUARANTEED_NO_EVICT-pytorch-stress-test]
+  - test_e2e.py::test_openai_chat_structural_tag_example
 - condition:
     ranges:
       system_gpu_count:

tests/unittest/api_stability/references/guided_decoding_params.yaml

Lines changed: 3 additions & 0 deletions
@@ -13,5 +13,8 @@ methods:
       regex:
         annotation: Optional[str]
         default: null
+      structural_tag:
+        annotation: Optional[str]
+        default: null
     return_annotation: None
 properties: {}
