
Commit 39a474d

Implement tools and outputs for the SGLang model
1 parent 62acef9


3 files changed, +89 -48 lines


outlines/models/sglang.py

Lines changed: 51 additions & 24 deletions
@@ -3,12 +3,14 @@
 import json
 import warnings
 from typing import (
-    TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union
+    TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
 )
 
 from outlines.inputs import Chat
 from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
 from outlines.models.openai import OpenAITypeAdapter
+from outlines.outputs import Output, StreamingOutput
+from outlines.tools import ToolDef
 from outlines.types.dsl import (
     CFG,
     JsonSchema,
@@ -44,7 +46,7 @@ def format_input(self, model_input: Union[Chat, list, str]) -> list:
         """
         return OpenAITypeAdapter().format_input(model_input)
 
-    def format_output_type(self, output_type: Optional[Any] = None) -> dict:
+    def format_output_type(self, output_type: Optional[Any]) -> dict:
         """Generate the structured output argument to pass to the client.
 
         Parameters
@@ -78,6 +80,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
         else:
             return {"extra_body": {"regex": to_regex(term)}}
 
+    def format_tools(self, tools):
+        """Not available for SGLang."""
+        if tools:
+            raise NotImplementedError(
+                "Tools are not available for SGLang."
+            )
+
 
 class SGLang(Model):
     """Thin wrapper around the `openai.OpenAI` client used to communicate with
@@ -106,9 +115,10 @@ def __init__(self, client, model_name: Optional[str] = None):
     def generate(
         self,
         model_input: Union[Chat, list, str],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Union[str, list[str]]:
+    ) -> Output | list[Output]:
         """Generate text using SGLang.
 
         Parameters
@@ -119,15 +129,18 @@ def generate(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        Union[str, list[str]]
+        Output | list[Output]
             The text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input,
             output_type,
@@ -145,14 +158,15 @@ def generate(
         )
 
         if len(messages) == 1:
-            return messages[0].content
+            return Output(content=messages[0].content)
         else:
-            return [message.content for message in messages]
+            return [Output(content=message.content) for message in messages]
 
     def generate_batch(
         self,
         model_input,
-        output_type = None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError(
@@ -162,9 +176,10 @@ def generate_batch(
     def generate_stream(
         self,
         model_input: Union[Chat, list, str],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Iterator[str]:
+    ) -> Iterator[StreamingOutput]:
         """Stream text using SGLang.
 
         Parameters
@@ -175,15 +190,18 @@ def generate_stream(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        Iterator[str]
+        Iterator[StreamingOutput]
             An iterator that yields the text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -194,12 +212,12 @@ def generate_stream(
 
         for chunk in stream:  # pragma: no cover
             if chunk.choices and chunk.choices[0].delta.content is not None:
-                yield chunk.choices[0].delta.content
+                yield StreamingOutput(content=chunk.choices[0].delta.content)
 
     def _build_client_args(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the SGLang client."""
@@ -250,9 +268,10 @@ def __init__(self, client, model_name: Optional[str] = None):
     async def generate(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Union[str, list[str]]:
+    ) -> Union[Output, list[Output]]:
         """Generate text using `sglang`.
 
         Parameters
@@ -263,15 +282,18 @@ async def generate(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        Union[str, list[str]]
+        Union[Output, list[Output]]
             The text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -287,14 +309,15 @@ async def generate(
         )
 
         if len(messages) == 1:
-            return messages[0].content
+            return Output(content=messages[0].content)
         else:
-            return [message.content for message in messages]
+            return [Output(content=message.content) for message in messages]
 
     async def generate_batch(
         self,
         model_input,
-        output_type = None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError(
@@ -304,9 +327,10 @@ async def generate_batch(
     async def generate_stream(  # type: ignore
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> AsyncIterator[str]:
+    ) -> AsyncIterator[StreamingOutput]:
         """Return a text generator.
 
         Parameters
@@ -317,15 +341,18 @@ async def generate_stream(  # type: ignore
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        AsyncIterator[str]
+        AsyncIterator[StreamingOutput]
             An async iterator that yields the text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -337,12 +364,12 @@ async def generate_stream(  # type: ignore
 
         async for chunk in stream:  # pragma: no cover
             if chunk.choices and chunk.choices[0].delta.content is not None:
-                yield chunk.choices[0].delta.content
+                yield StreamingOutput(content=chunk.choices[0].delta.content)
 
     def _build_client_args(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the SGLang client."""

tests/models/test_sglang.py

Lines changed: 25 additions & 24 deletions
@@ -15,6 +15,7 @@
 
 from outlines.inputs import Chat, Image
 from outlines.models.sglang import SGLang, AsyncSGLang, from_sglang
+from outlines.outputs import Output, StreamingOutput
 from outlines.types.dsl import CFG, Regex, JsonSchema
 from tests.test_utils.mock_openai_client import MockOpenAIClient, MockAsyncOpenAIClient
 
@@ -231,7 +232,7 @@ def test_sglang_init():
 
 def test_sglang_sync_simple_call(sync_model):
     result = sync_model("Respond with a single word.",)
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 def test_sglang_sync_streaming(sync_model_no_model_name):
@@ -240,7 +241,7 @@ def test_sglang_sync_streaming(sync_model_no_model_name):
         model=sglang_model_name,
     )
     assert isinstance(result, Generator)
-    assert isinstance(next(result), str)
+    assert isinstance(next(result), StreamingOutput)
 
 
 def test_sglang_sync_batch(sync_model):
@@ -252,7 +253,7 @@ def test_sglang_sync_batch(sync_model):
 
 def test_sglang_sync_vision(sync_model):
     result = sync_model(["hello", image_input], max_tokens=10)
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 def test_sglang_sync_vision_chat(sync_model):
@@ -267,15 +268,15 @@ def test_sglang_sync_vision_chat(sync_model):
         ]),
         max_tokens=10,
     )
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 def test_sglang_sync_multiple_samples(sync_model):
     result = sync_model("Respond with a single word.", n=2)
     assert isinstance(result, list)
     assert len(result) == 2
-    assert isinstance(result[0], str)
-    assert isinstance(result[1], str)
+    assert isinstance(result[0], Output)
+    assert isinstance(result[1], Output)
 
 
 def test_sglang_sync_json(sync_model):
@@ -284,14 +285,14 @@ def test_sglang_sync_json(sync_model):
         + ' {"bar": {"type": "string"}}}'
     )
     result = sync_model("foo?", JsonSchema(json_string), max_tokens=10)
-    assert isinstance(result, str)
-    assert "bar" in result
+    assert isinstance(result, Output)
+    assert "bar" in result.content
 
 
 def test_sglang_sync_regex(sync_model):
     result = sync_model("foo?", Regex(r"[0-9]{3}"), max_tokens=10)
-    assert isinstance(result, str)
-    assert re.match(r"[0-9]{3}", result)
+    assert isinstance(result, Output)
+    assert re.match(r"[0-9]{3}", result.content)
 
 
 def test_sglang_sync_cfg(sync_model):
@@ -300,14 +301,14 @@ def test_sglang_sync_cfg(sync_model):
         match="SGLang grammar-based structured outputs expects an EBNF"
     ):
         result = sync_model("foo?", CFG(EBNF_YES_NO_GRAMMAR), max_tokens=10)
-        assert isinstance(result, str)
-        assert result in ["yes", "no"]
+        assert isinstance(result, Output)
+        assert result.content in ["yes", "no"]
 
 
 @pytest.mark.asyncio
 async def test_sglang_async_simple_call(async_model):
     result = await async_model("Respond with a single word.",)
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 @pytest.mark.asyncio
@@ -318,7 +319,7 @@ async def test_sglang_async_streaming(async_model_no_model_name):
     )
     assert isinstance(result, AsyncGenerator)
     async for chunk in result:
-        assert isinstance(chunk, str)
+        assert isinstance(chunk, StreamingOutput)
         break  # Just check the first chunk
 
 
@@ -333,7 +334,7 @@ async def test_sglang_async_batch(async_model):
 @pytest.mark.asyncio
 async def test_sglang_async_vision(async_model):
     result = await async_model(["hello", image_input], max_tokens=10)
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 @pytest.mark.asyncio
@@ -349,16 +350,16 @@ async def test_sglang_async_vision_chat(async_model):
         ]),
         max_tokens=10,
     )
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 @pytest.mark.asyncio
 async def test_sglang_async_multiple_samples(async_model):
     result = await async_model("Respond with a single word.", n=2)
     assert isinstance(result, list)
     assert len(result) == 2
-    assert isinstance(result[0], str)
-    assert isinstance(result[1], str)
+    assert isinstance(result[0], Output)
+    assert isinstance(result[1], Output)
 
 
 @pytest.mark.asyncio
@@ -368,19 +369,19 @@ async def test_sglang_async_json(async_model):
         + ' {"bar": {"type": "string"}}}'
     )
     result = await async_model("foo?", JsonSchema(json_string), max_tokens=10)
-    assert isinstance(result, str)
-    assert "bar" in result
+    assert isinstance(result, Output)
+    assert "bar" in result.content
 
 
 @pytest.mark.asyncio
 async def test_sglang_async_regex(async_model):
     result = await async_model("foo?", Regex(r"[0-9]{3}"), max_tokens=10)
-    assert isinstance(result, str)
-    assert re.match(r"[0-9]{3}", result)
+    assert isinstance(result, Output)
+    assert re.match(r"[0-9]{3}", result.content)
 
 
 @pytest.mark.asyncio
 async def test_sglang_async_cfg(async_model):
     result = await async_model("foo?", CFG(EBNF_YES_NO_GRAMMAR), max_tokens=10)
-    assert isinstance(result, str)
-    assert result in ["yes", "no"]
+    assert isinstance(result, Output)
+    assert result.content in ["yes", "no"]
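
For reference, a small sketch of the assertion pattern the updated tests follow: substring, regex, and JSON checks move from the result itself to its .content attribute. Only the Output(content=...) constructor is taken from the diff above; the values are illustrative.

import json
import re

from outlines.outputs import Output

# The raw generated text now lives on .content, so substring, regex,
# and JSON checks all go through that attribute.
output = Output(content='{"bar": "baz"}')
assert isinstance(output, Output)
assert "bar" in output.content
assert json.loads(output.content)["bar"] == "baz"

digits = Output(content="123")
assert re.match(r"[0-9]{3}", digits.content)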
