15
15
from vllm import LLM , SamplingParams
16
16
from vllm .sampling_params import GuidedDecodingParams
17
17
18
# Upper bound on generated tokens, shared by the sampling configurations below
# so each guided-decoding example uses the same explicit limit.
MAX_TOKENS = 50
18
20
# Guided decoding by Choice (list of possible options)
guided_decoding_params_choice = GuidedDecodingParams(
    choice=["Positive", "Negative"],
)
# Sampling configuration that carries the choice constraint defined above.
sampling_params_choice = SamplingParams(
    guided_decoding=guided_decoding_params_choice,
)
23
25
# Guided decoding by Regex
# The pattern describes a simple "<word>@<word>.com" address terminated by a
# newline.
guided_decoding_params_regex = GuidedDecodingParams(regex=r"\w+@\w+\.com\n")
sampling_params_regex = SamplingParams(
    guided_decoding=guided_decoding_params_regex,
    # The regex ends in "\n" with nothing after it, so a bare newline is the
    # only stop string that can actually appear in the constrained output.
    stop=["\n"],
    max_tokens=MAX_TOKENS,
)
28
32
prompt_regex = (
29
33
"Generate an email address for Alan Turing, who works in Enigma."
@@ -48,7 +52,10 @@ class CarDescription(BaseModel):
48
52
49
53
# Derive the JSON schema from the CarDescription model and hand it to the
# guided-decoding parameters.
json_schema = CarDescription.model_json_schema()
guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
# Sampling configuration for the JSON example, capped at MAX_TOKENS tokens.
sampling_params_json = SamplingParams(
    guided_decoding=guided_decoding_params_json,
    max_tokens=MAX_TOKENS,
)
52
59
prompt_json = (
53
60
"Generate a JSON with the brand, model and car_type of"
54
61
"the most iconic car from the 90's"
@@ -64,7 +71,10 @@ class CarDescription(BaseModel):
64
71
number ::= "1 " | "2 "
65
72
"""
66
73
# Pass the simplified SQL grammar defined above to the guided-decoding
# parameters.
guided_decoding_params_grammar = GuidedDecodingParams(
    grammar=simplified_sql_grammar,
)
# Sampling configuration for the grammar example, capped at MAX_TOKENS tokens.
sampling_params_grammar = SamplingParams(
    guided_decoding=guided_decoding_params_grammar,
    max_tokens=MAX_TOKENS,
)
68
78
# Prompt for the grammar-guided example. The literal is split across two lines;
# the first piece ends with a space so the concatenated text reads correctly.
prompt_grammar = (
    "Generate an SQL query to show the 'username' and 'email' "
    "from the 'users' table."
)
0 commit comments