11"""Integration with a vLLM server."""
22
33import json
4- from typing import TYPE_CHECKING , Any , AsyncIterator , Iterator , Optional , Union
4+ from typing import (
5+ TYPE_CHECKING ,
6+ Any ,
7+ AsyncIterator ,
8+ Iterator ,
9+ List ,
10+ Optional ,
11+ Union ,
12+ )
513
614from outlines .inputs import Chat
715from outlines .models .base import AsyncModel ,Model , ModelTypeAdapter
816from outlines .models .openai import OpenAITypeAdapter
17+ from outlines .outputs import Output , StreamingOutput
18+ from outlines .tools import ToolDef
919from outlines .types .dsl import CFG , JsonSchema , python_types_to_terms , to_regex
1020
1121if TYPE_CHECKING :
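The two new imports carry the wrapper types that replace raw strings throughout this diff. Their definitions are not shown here; judging from the `Output(content=...)` and `StreamingOutput(content=...)` constructions below, both behave as lightweight containers with at least a `content` field. A rough sketch of that assumed shape, not the actual `outlines.outputs` definitions:

    from dataclasses import dataclass

    @dataclass
    class Output:
        # Assumed shape, inferred from Output(content=...) below.
        content: str

    @dataclass
    class StreamingOutput:
        # Assumed shape, inferred from StreamingOutput(content=...) below.
        content: str
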
@@ -36,7 +46,7 @@ def format_input(self, model_input: Union[Chat, str, list]) -> list:
3646 """
3747 return OpenAITypeAdapter ().format_input (model_input )
3848
39- def format_output_type (self , output_type : Optional [Any ] = None ) -> dict :
49+ def format_output_type (self , output_type : Optional [Any ]) -> dict :
4050 """Generate the structured output argument to pass to the client.
4151
4252 Parameters
@@ -64,6 +74,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
         else:
             return {"guided_regex": to_regex(term)}

+    def format_tools(self, tools):
+        """Not available for VLLM."""
+        if tools:
+            raise NotImplementedError(
+                "Tools are not available for VLLM."
+            )
+

 class VLLM(Model):
     """Thin wrapper around the `openai.OpenAI` client used to communicate with
@@ -93,9 +110,10 @@ def __init__(
     def generate(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Union[str, list[str]]:
+    ) -> Union[Output, list[Output]]:
         """Generate text using vLLM.

         Parameters
@@ -106,15 +124,18 @@ def generate(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        Union[str, list[str]]
+        Union[Output, list[Output]]
             The text generated by the model.

         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input,
             output_type,
@@ -132,24 +153,26 @@ def generate(
         )

         if len(messages) == 1:
-            return messages[0].content
+            return Output(content=messages[0].content)
         else:
-            return [message.content for message in messages]
+            return [Output(content=message.content) for message in messages]

     def generate_batch(
         self,
         model_input,
-        output_type=None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError("VLLM does not support batch inference.")

     def generate_stream(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Iterator[str]:
+    ) -> Iterator[StreamingOutput]:
         """Stream text using vLLM.

         Parameters
@@ -160,15 +183,18 @@ def generate_stream(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        Iterator[str]
+        Iterator[StreamingOutput]
             An iterator that yields the text generated by the model.

         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
@@ -179,12 +205,12 @@ def generate_stream(

         for chunk in stream:  # pragma: no cover
             if chunk.choices and chunk.choices[0].delta.content is not None:
-                yield chunk.choices[0].delta.content
+                yield StreamingOutput(content=chunk.choices[0].delta.content)

     def _build_client_args(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the OpenAI client."""
@@ -234,9 +260,10 @@ def __init__(
     async def generate(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Union[str, list[str]]:
+    ) -> Union[Output, list[Output]]:
         """Generate text using vLLM.

         Parameters
@@ -247,12 +274,14 @@ async def generate(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        Union[str, list[str]]
+        Union[Output, list[Output]]
             The text generated by the model.

         """
@@ -271,24 +300,26 @@ async def generate(
         )

         if len(messages) == 1:
-            return messages[0].content
+            return Output(content=messages[0].content)
         else:
-            return [message.content for message in messages]
+            return [Output(content=message.content) for message in messages]

     async def generate_batch(
         self,
         model_input,
-        output_type=None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError("VLLM does not support batch inference.")

     async def generate_stream(  # type: ignore
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> AsyncIterator[str]:
+    ) -> AsyncIterator[StreamingOutput]:
         """Stream text using vLLM.

         Parameters
@@ -299,13 +330,16 @@ async def generate_stream( # type: ignore
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        AsyncIterator[str]
+        AsyncIterator[StreamingOutput]
             An async iterator that yields the text generated by the model.
+
         """
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
@@ -318,12 +352,12 @@ async def generate_stream( # type: ignore

         async for chunk in stream:  # pragma: no cover
             if chunk.choices and chunk.choices[0].delta.content is not None:
-                yield chunk.choices[0].delta.content
+                yield StreamingOutput(content=chunk.choices[0].delta.content)

     def _build_client_args(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the OpenAI client."""