Skip to content

Commit d1331a4

Browse files
shen-shanshangooglercolin
authored andcommitted
[Structured Output] Make the output of structured output example more complete (vllm-project#22481)
Signed-off-by: shen-shanshan <[email protected]>
1 parent d1759cd commit d1331a4

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

examples/offline_inference/structured_outputs.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from vllm import LLM, SamplingParams
1616
from vllm.sampling_params import GuidedDecodingParams
1717

18+
MAX_TOKENS = 50
19+
1820
# Guided decoding by Choice (list of possible options)
1921
guided_decoding_params_choice = GuidedDecodingParams(choice=["Positive", "Negative"])
2022
sampling_params_choice = SamplingParams(guided_decoding=guided_decoding_params_choice)
@@ -23,7 +25,9 @@
2325
# Guided decoding by Regex
2426
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
2527
sampling_params_regex = SamplingParams(
26-
guided_decoding=guided_decoding_params_regex, stop=["\n"]
28+
guided_decoding=guided_decoding_params_regex,
29+
stop=["\n"],
30+
max_tokens=MAX_TOKENS,
2731
)
2832
prompt_regex = (
2933
"Generate an email address for Alan Turing, who works in Enigma."
@@ -48,7 +52,10 @@ class CarDescription(BaseModel):
4852

4953
json_schema = CarDescription.model_json_schema()
5054
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
51-
sampling_params_json = SamplingParams(guided_decoding=guided_decoding_params_json)
55+
sampling_params_json = SamplingParams(
56+
guided_decoding=guided_decoding_params_json,
57+
max_tokens=MAX_TOKENS,
58+
)
5259
prompt_json = (
5360
"Generate a JSON with the brand, model and car_type of"
5461
"the most iconic car from the 90's"
@@ -64,7 +71,10 @@ class CarDescription(BaseModel):
6471
number ::= "1 " | "2 "
6572
"""
6673
guided_decoding_params_grammar = GuidedDecodingParams(grammar=simplified_sql_grammar)
67-
sampling_params_grammar = SamplingParams(guided_decoding=guided_decoding_params_grammar)
74+
sampling_params_grammar = SamplingParams(
75+
guided_decoding=guided_decoding_params_grammar,
76+
max_tokens=MAX_TOKENS,
77+
)
6878
prompt_grammar = (
6979
"Generate an SQL query to show the 'username' and 'email'from the 'users' table."
7080
)

0 commit comments

Comments
 (0)