|
13 | 13 | from typing_extensions import Annotated, Required, TypedDict |
14 | 14 |
|
15 | 15 | from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams |
16 | | -from tensorrt_llm.llmapi import SamplingParams |
| 16 | +from tensorrt_llm.llmapi import GuidedDecodingParams, SamplingParams |
17 | 17 |
|
18 | 18 |
|
19 | 19 | class OpenAIBaseModel(BaseModel): |
@@ -44,9 +44,17 @@ class ModelList(OpenAIBaseModel): |
44 | 44 | data: List[ModelCard] = Field(default_factory=list) |
45 | 45 |
|
46 | 46 |
|
class StructuralTag(OpenAIBaseModel):
    """One structural tag: literal *begin* text, an optional JSON schema
    constraining the tagged content, and literal *end* text.

    Serialized with ``by_alias=True`` so the wire-format field is named
    ``schema`` (the attribute is ``schema_`` to avoid shadowing
    ``BaseModel.schema``).
    """
    # Literal text that opens the tagged region (e.g. a tool-call marker).
    begin: str
    # JSON schema for the content between begin and end.
    # default=None added: without a default, Pydantic treats an aliased field
    # as required even though it is annotated Optional — the annotation and
    # the validation behavior now agree.
    schema_: Optional[dict[str, Any]] = Field(default=None, alias="schema")
    # Literal text that closes the tagged region.
    end: str
class ResponseFormat(OpenAIBaseModel):
    """Requested output format for a completion.

    ``type`` selects plain text, constrained JSON-object output, or
    structural-tag guided decoding; the remaining fields are only
    meaningful for ``structural_tag``.
    """
    # Which constrained-decoding mode to apply: "text", "json_object",
    # or "structural_tag".
    type: Literal["text", "json_object", "structural_tag"]
    # Tag specifications; used only when type == "structural_tag".
    structures: Optional[List[StructuralTag]] = None
    # Trigger strings that activate tag matching; used only when
    # type == "structural_tag".
    triggers: Optional[List[str]] = None
|
51 | 59 |
|
52 | 60 | class DisaggregatedParams(OpenAIBaseModel): |
@@ -121,6 +129,23 @@ class CompletionStreamResponse(OpenAIBaseModel): |
121 | 129 | usage: Optional[UsageInfo] = Field(default=None) |
122 | 130 |
|
123 | 131 |
|
def _response_format_to_guided_decoding_params(
    response_format: Optional[ResponseFormat]
) -> Optional[GuidedDecodingParams]:
    """Translate an OpenAI ``response_format`` into guided-decoding params.

    Returns None when no guidance is required (format absent or plain
    text); raises ValueError for an unrecognized format type.
    """
    if response_format is None:
        return None

    kind = response_format.type
    if kind == "text":
        # Plain text needs no constrained decoding.
        return None
    if kind == "json_object":
        return GuidedDecodingParams(json_object=True)
    if kind == "structural_tag":
        # Forward the whole response_format (type/structures/triggers) as a
        # single JSON string; by_alias keeps the wire name "schema".
        tag_spec = response_format.model_dump_json(by_alias=True,
                                                   exclude_none=True)
        return GuidedDecodingParams(structural_tag=tag_spec)
    raise ValueError(f"Unsupported response format: {kind}")
124 | 149 | class CompletionRequest(OpenAIBaseModel): |
125 | 150 | # Ordered by official OpenAI API documentation |
126 | 151 | # https://platform.openai.com/docs/api-reference/completions/create |
@@ -170,10 +195,10 @@ class CompletionRequest(OpenAIBaseModel): |
170 | 195 | ) |
171 | 196 | response_format: Optional[ResponseFormat] = Field( |
172 | 197 | default=None, |
173 | | - description=( |
174 | | - "Similar to chat completion, this parameter specifies the format of " |
175 | | - "output. Only {'type': 'json_object'} or {'type': 'text' } is " |
176 | | - "supported."), |
| 198 | + description= |
| 199 | + ("Similar to chat completion, this parameter specifies the format of " |
| 200 | + "output. {'type': 'json_object'}, {'type': 'text' }, {'type': 'structural_tag'} are " |
| 201 | + "supported."), |
177 | 202 | ) |
178 | 203 |
|
179 | 204 | disaggregated_params: Optional[DisaggregatedParams] = Field( |
@@ -211,6 +236,8 @@ def to_sampling_params(self) -> SamplingParams: |
211 | 236 | spaces_between_special_tokens=self.spaces_between_special_tokens, |
212 | 237 | truncate_prompt_tokens=self.truncate_prompt_tokens, |
213 | 238 | return_context_logits=self.return_context_logits, |
| 239 | + guided_decoding=_response_format_to_guided_decoding_params( |
| 240 | + self.response_format), |
214 | 241 |
|
215 | 242 | # completion-extra-params |
216 | 243 | add_special_tokens=self.add_special_tokens, |
@@ -255,13 +282,6 @@ def verify_multi_responses(cls, data): |
255 | 282 | raise ValueError("best_of should not be smaller than n") |
256 | 283 | return data |
257 | 284 |
|
258 | | - @model_validator(mode="before") |
259 | | - @classmethod |
260 | | - def check_response_format(cls, data): |
261 | | - if data.get("response_format"): |
262 | | - raise ValueError("response_format is not supported") |
263 | | - return data |
264 | | - |
265 | 285 | @model_validator(mode="before") |
266 | 286 | @classmethod |
267 | 287 | def check_suffix(cls, data): |
@@ -520,6 +540,8 @@ def to_sampling_params(self) -> SamplingParams: |
520 | 540 | skip_special_tokens=self.skip_special_tokens, |
521 | 541 | spaces_between_special_tokens=self.spaces_between_special_tokens, |
522 | 542 | truncate_prompt_tokens=self.truncate_prompt_tokens, |
| 543 | + guided_decoding=_response_format_to_guided_decoding_params( |
| 544 | + self.response_format), |
523 | 545 |
|
524 | 546 | # chat-completion-extra-params |
525 | 547 | add_special_tokens=self.add_special_tokens, |
@@ -582,13 +604,6 @@ def verify_logit_processor(cls, data): |
582 | 604 | raise ValueError("logit bias is not supported") |
583 | 605 | return data |
584 | 606 |
|
585 | | - @model_validator(mode="before") |
586 | | - @classmethod |
587 | | - def check_response_format(cls, data): |
588 | | - if data.get("response_format"): |
589 | | - raise ValueError("response_format is not supported") |
590 | | - return data |
591 | | - |
592 | 607 | @model_validator(mode="before") |
593 | 608 | @classmethod |
594 | 609 | def check_suffix(cls, data): |
|
0 commit comments