 from outlines.inputs import Chat
 from outlines.models.base import Model, ModelTypeAdapter
 from outlines.models.openai import OpenAITypeAdapter
+from outlines.outputs import Output
+from outlines.tools import ToolDef
 from outlines.types.dsl import CFG, JsonSchema, python_types_to_terms, to_regex
 
 if TYPE_CHECKING:
@@ -56,7 +58,7 @@ def format_input_chat(self, model_input: Chat) -> list:
             )
         return OpenAITypeAdapter().format_input(model_input)
 
-    def format_output_type(self, output_type: Optional[Any] = None) -> dict:
+    def format_output_type(self, output_type: Optional[Any]) -> dict:
         """Generate the structured output argument to pass to the model.
 
         For vLLM, the structured output definition is set in the
@@ -90,6 +92,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
         else:
             return {"regex": to_regex(term)}
 
+    def format_tools(self, tools):
+        """Not available for VLLM offline."""
+        if tools:
+            raise NotImplementedError(
+                "Tools are not available for VLLM offline."
+            )
+
 
 class VLLMOffline(Model):
     """Thin wrapper around a `vllm.LLM` model.
@@ -114,7 +123,7 @@ def __init__(self, model: "LLM"):
     def _build_generation_args(
         self,
         inference_kwargs: dict,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
     ) -> "SamplingParams":
         """Create the `SamplingParams` object to pass to the `generate` method
         of the `vllm.LLM` model."""
@@ -134,9 +143,10 @@ def _build_generation_args(
     def generate(
         self,
         model_input: Chat | str,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Union[str, List[str]]:
+    ) -> Union[Output, List[Output]]:
         """Generate text using vLLM offline.
 
         Parameters
@@ -146,16 +156,19 @@ def generate(
         output_type
             The logits processor the model will use to constrain the format of
             the generated text.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the `generate` method
             in the `vllm.LLM` model.
 
         Returns
         -------
-        Union[str, List[str]]
+        Union[Output, List[Output]]
             The text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         sampling_params = self._build_generation_args(
             inference_kwargs,
             output_type,
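
Since `output_type` and `tools` lose their defaults in this diff, callers must now pass both explicitly, even when unused. A hedged usage sketch (the model id and the `VLLMOffline(LLM(...))` wiring are illustrative, not mandated by the diff):

```python
from vllm import LLM  # requires vLLM installed

# Illustrative wiring; outlines may expose a different public constructor.
model = VLLMOffline(LLM("microsoft/Phi-3-mini-4k-instruct"))

# Both new positional arguments are passed explicitly; None means
# "no structured output" and "no tools".
output = model.generate("Write a haiku about autumn.", None, None)
print(output.content)  # an Output object now, not a bare str
```
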
@@ -168,24 +181,25 @@ def generate(
                 **inference_kwargs,
             )
         else:
-            results = self.model.generate(
+            results = self.model(
                 prompts=self.type_adapter.format_input(model_input),
                 sampling_params=sampling_params,
                 **inference_kwargs,
             )
         results = [completion.text for completion in results[0].outputs]
 
         if len(results) == 1:
-            return results[0]
+            return Output(content=results[0])
         else:
-            return results
+            return [Output(content=result) for result in results]
 
     def generate_batch(
         self,
         model_input: List[Chat | str],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Union[List[str], List[List[str]]]:
+    ) -> Union[List[Output], List[List[Output]]]:
         """Generate a batch of completions using vLLM offline.
 
         Parameters
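
The unwrap above makes the return type of `generate` depend on how many samples vLLM produced for the prompt. Continuing the sketch, and assuming a `sampling_params` keyword is still popped out of `inference_kwargs` by `_build_generation_args` as in the pre-existing code:

```python
from vllm import SamplingParams
from outlines.outputs import Output

# One completion -> a single Output
single = model.generate("Name a color.", None, None)
assert isinstance(single, Output)

# n > 1 completions -> a list of Output, one per sample (assumed
# forwarding of `sampling_params` through **inference_kwargs)
several = model.generate(
    "Name a color.", None, None,
    sampling_params=SamplingParams(n=3),
)
assert isinstance(several, list) and len(several) == 3
```
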
@@ -196,16 +210,19 @@ def generate_batch(
         output_type
             The logits processor the model will use to constrain the format of
             the generated text.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the `generate` method
             in the `vllm.LLM` model.
 
         Returns
         -------
-        Union[List[str], List[List[str]]]
+        Union[List[Output], List[List[Output]]]
             The text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         sampling_params = self._build_generation_args(
             inference_kwargs,
             output_type,
@@ -216,14 +233,20 @@ def generate_batch(
216233 "Batch generation is not available for the `Chat` input type."
217234 )
218235
219- results = self .model . generate (
236+ results = self .model (
220237 prompts = [self .type_adapter .format_input (item ) for item in model_input ],
221238 sampling_params = sampling_params ,
222239 ** inference_kwargs ,
223240 )
224- return [[sample .text for sample in batch .outputs ] for batch in results ]
225241
226- def generate_stream (self , model_input , output_type , ** inference_kwargs ):
242+ return [ # type: ignore
243+ [Output (content = sample .text ) for sample in batch .outputs ]
244+ if len (batch .outputs ) > 1
245+ else Output (content = batch .outputs [0 ].text )
246+ for batch in results
247+ ]
248+
249+ def generate_stream (self , model_input , output_type , tools , ** inference_kwargs ):
227250 """Not available for `vllm.LLM`.
228251
229252 TODO: Implement the streaming functionality ourselves.
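
`generate_batch` now applies the same per-prompt unwrap as `generate`, which is why each element is `Output` rather than `List[Output]` when only one sample is drawn, and why the comprehension needs a `# type: ignore`. Continuing the sketch:

```python
# Default sampling draws one completion per prompt, so each element
# collapses to a bare Output rather than a one-element list.
results = model.generate_batch(["Say hi.", "Say bye."], None, None)
for result in results:
    print(result.content)

# Streaming stays unsupported on the offline backend; per the docstring it
# is "Not available", so calling generate_stream presumably raises
# (assumption: NotImplementedError, matching the format_tools pattern).
```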