 import json
 import warnings
 from typing import (
-    TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union
+    TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
 )

 from outlines.inputs import Chat
 from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
 from outlines.models.openai import OpenAITypeAdapter
+from outlines.outputs import Output, StreamingOutput
+from outlines.tools import ToolDef
 from outlines.types.dsl import (
     CFG,
     JsonSchema,
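The hunks below construct `Output(content=...)` and `StreamingOutput(content=...)` as the new return containers. As a rough mental model only (the real definitions live in `outlines.outputs` and are not part of this diff), they can be pictured as:

```python
# Hedged sketch, NOT the actual outlines.outputs definitions: only the
# `content` keyword used by this diff is assumed here.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Output:
    """Completed generation; `content` holds the generated text."""
    content: Optional[str] = None


@dataclass
class StreamingOutput:
    """One streamed delta; `content` holds the text fragment."""
    content: Optional[str] = None
```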
@@ -44,7 +46,7 @@ def format_input(self, model_input: Union[Chat, list, str]) -> list:
         """
         return OpenAITypeAdapter().format_input(model_input)

-    def format_output_type(self, output_type: Optional[Any] = None) -> dict:
+    def format_output_type(self, output_type: Optional[Any]) -> dict:
         """Generate the structured output argument to pass to the client.

         Parameters
@@ -78,6 +80,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
         else:
             return {"extra_body": {"regex": to_regex(term)}}

+    def format_tools(self, tools):
+        """Not available for SGLang."""
+        if tools:
+            raise NotImplementedError(
+                "Tools are not available for SGLang."
+            )
+

 class SGLang(Model):
     """Thin wrapper around the `openai.OpenAI` client used to communicate with
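The new `format_tools` hook refuses any tool definitions, so tool-less calls pass through unchanged while anything truthy fails fast. A hedged illustration (the `SGLangTypeAdapter` name and import path are assumptions based on the surrounding file, not shown in this hunk):

```python
# Illustration of the behaviour added above; the adapter class name and
# import path are assumptions, and the check itself only cares about
# truthiness, so plain strings stand in for real ToolDef objects.
from outlines.models.sglang import SGLangTypeAdapter  # assumed import path

adapter = SGLangTypeAdapter()

adapter.format_tools(None)  # falsy: accepted, returns None
adapter.format_tools([])    # empty list: accepted as well

try:
    adapter.format_tools(["any non-empty value"])  # stand-in for real ToolDefs
except NotImplementedError as exc:
    print(exc)  # Tools are not available for SGLang.
```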
@@ -106,9 +115,10 @@ def __init__(self, client, model_name: Optional[str] = None):
     def generate(
         self,
         model_input: Union[Chat, list, str],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Union[str, list[str]]:
+    ) -> Output | list[Output]:
         """Generate text using SGLang.

         Parameters
@@ -119,15 +129,18 @@ def generate(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        Union[str, list[str]]
+        Output | list[Output]
             The text generated by the model.

         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input,
             output_type,
@@ -145,14 +158,15 @@ def generate(
         )

         if len(messages) == 1:
-            return messages[0].content
+            return Output(content=messages[0].content)
         else:
-            return [message.content for message in messages]
+            return [Output(content=message.content) for message in messages]

     def generate_batch(
         self,
         model_input,
-        output_type=None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError(
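Callers of the sync wrapper now get `Output` objects instead of raw strings. A usage sketch under assumptions: the server URL and model name are placeholders, and `outlines.from_sglang` is used as the factory; the point being made is only that the generated text now lives on `.content`:

```python
# Hedged usage sketch: URL and model name are placeholders for a running
# SGLang server exposing an OpenAI-compatible endpoint.
import openai
import outlines

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
model = outlines.from_sglang(client, "qwen/qwen2.5-1.5b-instruct")

result = model("What is the capital of France?", max_tokens=32)
print(result.content)  # previously `result` was the string itself
# With n > 1 in the inference kwargs, a list of Output objects comes back.
```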
@@ -162,9 +176,10 @@ def generate_batch(
     def generate_stream(
         self,
         model_input: Union[Chat, list, str],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Iterator[str]:
+    ) -> Iterator[StreamingOutput]:
         """Stream text using SGLang.

         Parameters
@@ -175,15 +190,18 @@ def generate_stream(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        Iterator[str]
+        Iterator[StreamingOutput]
             An iterator that yields the text generated by the model.

         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -194,12 +212,12 @@ def generate_stream(

         for chunk in stream:  # pragma: no cover
             if chunk.choices and chunk.choices[0].delta.content is not None:
-                yield chunk.choices[0].delta.content
+                yield StreamingOutput(content=chunk.choices[0].delta.content)

     def _build_client_args(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the SGLang client."""
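Streaming mirrors the same change: each delta is wrapped in a `StreamingOutput`. Continuing the sketch above, and assuming the model's `stream(...)` interface routes to `generate_stream`:

```python
# Hedged streaming sketch, reusing `model` from the previous example.
for chunk in model.stream("Write a haiku about the sea.", max_tokens=64):
    print(chunk.content, end="", flush=True)  # was a bare string delta before
```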
@@ -250,9 +268,10 @@ def __init__(self, client, model_name: Optional[str] = None):
     async def generate(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Union[str, list[str]]:
+    ) -> Union[Output, list[Output]]:
         """Generate text using `sglang`.

         Parameters
@@ -263,15 +282,18 @@ async def generate(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        Union[str, list[str]]
+        Union[Output, list[Output]]
             The text generated by the model.

         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -287,14 +309,15 @@ async def generate(
         )

         if len(messages) == 1:
-            return messages[0].content
+            return Output(content=messages[0].content)
         else:
-            return [message.content for message in messages]
+            return [Output(content=message.content) for message in messages]

     async def generate_batch(
         self,
         model_input,
-        output_type=None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError(
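The async wrapper gets the same treatment. Sketched under the same assumptions, with `openai.AsyncOpenAI` standing in for a running SGLang server and the async wrapper class assumed to be returned by `outlines.from_sglang` when given an async client:

```python
# Hedged async sketch: URL and model name are placeholders.
import asyncio

import openai
import outlines


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
    model = outlines.from_sglang(client, "qwen/qwen2.5-1.5b-instruct")
    result = await model("Name three prime numbers.", max_tokens=32)
    print(result.content)  # Output, or a list of Output when n > 1


asyncio.run(main())
```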
@@ -304,9 +327,10 @@ async def generate_batch(
     async def generate_stream(  # type: ignore
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> AsyncIterator[str]:
+    ) -> AsyncIterator[StreamingOutput]:
         """Return a text generator.

         Parameters
@@ -317,15 +341,18 @@ async def generate_stream(  # type: ignore
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.

         Returns
         -------
-        AsyncIterator[str]
+        AsyncIterator[StreamingOutput]
             An async iterator that yields the text generated by the model.

         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -337,12 +364,12 @@ async def generate_stream(  # type: ignore

         async for chunk in stream:  # pragma: no cover
             if chunk.choices and chunk.choices[0].delta.content is not None:
-                yield chunk.choices[0].delta.content
+                yield StreamingOutput(content=chunk.choices[0].delta.content)

     def _build_client_args(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the SGLang client."""
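One consequence of dropping the defaults: `output_type` and `tools` are now required positional parameters on `generate`/`generate_stream`, so anything calling these methods directly (rather than through the model's call or stream interface) must pass them explicitly. A hedged sketch, reusing `model` from the sync example above:

```python
# Direct call sketch; positional order follows the new signatures in this diff.
output = model.generate(
    "Summarise this change in one sentence.",
    None,  # output_type: no structured output requested
    None,  # tools: anything truthy makes SGLang raise NotImplementedError
    max_tokens=48,
)
print(output.content)
```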