
Commit b3f5032

Implement tools and outputs for the VLLM model
1 parent 3e7f9a6 commit b3f5032

File tree

3 files changed (+96, -48 lines)


outlines/models/vllm.py

Lines changed: 58 additions & 24 deletions
@@ -1,11 +1,21 @@
 """Integration with a vLLM server."""
 
 import json
-from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Iterator,
+    List,
+    Optional,
+    Union,
+)
 
 from outlines.inputs import Chat
 from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
 from outlines.models.openai import OpenAITypeAdapter
+from outlines.outputs import Output, StreamingOutput
+from outlines.tools import ToolDef
 from outlines.types.dsl import CFG, JsonSchema, python_types_to_terms, to_regex
 
 if TYPE_CHECKING:
@@ -36,7 +46,7 @@ def format_input(self, model_input: Union[Chat, str, list]) -> list:
         """
         return OpenAITypeAdapter().format_input(model_input)
 
-    def format_output_type(self, output_type: Optional[Any] = None) -> dict:
+    def format_output_type(self, output_type: Optional[Any]) -> dict:
         """Generate the structured output argument to pass to the client.
 
         Parameters
@@ -64,6 +74,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
         else:
             return {"guided_regex": to_regex(term)}
 
+    def format_tools(self, tools):
+        """Not available for VLLM."""
+        if tools:
+            raise NotImplementedError(
+                "Tools are not available for VLLM."
+            )
+
 
 class VLLM(Model):
     """Thin wrapper around the `openai.OpenAI` client used to communicate with
@@ -93,9 +110,10 @@ def __init__(
     def generate(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Union[str, list[str]]:
+    ) -> Union[Output, list[Output]]:
         """Generate text using vLLM.
 
         Parameters
@@ -106,15 +124,18 @@ def generate(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        Union[str, list[str]]
+        Union[Output, list[Output]]
             The text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input,
             output_type,
@@ -132,24 +153,26 @@ def generate(
         )
 
         if len(messages) == 1:
-            return messages[0].content
+            return Output(content=messages[0].content)
         else:
-            return [message.content for message in messages]
+            return [Output(content=message.content) for message in messages]
 
     def generate_batch(
         self,
         model_input,
-        output_type = None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError("VLLM does not support batch inference.")
 
     def generate_stream(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Iterator[str]:
+    ) -> Iterator[StreamingOutput]:
         """Stream text using vLLM.
 
         Parameters
@@ -160,15 +183,18 @@ def generate_stream(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
        inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        Iterator[str]
+        Iterator[StreamingOutput]
             An iterator that yields the text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
         )
@@ -179,12 +205,12 @@ def generate_stream(
 
         for chunk in stream: # pragma: no cover
             if chunk.choices and chunk.choices[0].delta.content is not None:
-                yield chunk.choices[0].delta.content
+                yield StreamingOutput(content=chunk.choices[0].delta.content)
 
     def _build_client_args(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the OpenAI client."""
@@ -234,9 +260,10 @@ def __init__(
     async def generate(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> Union[str, list[str]]:
+    ) -> Union[Output, list[Output]]:
         """Generate text using vLLM.
 
         Parameters
@@ -247,12 +274,14 @@ async def generate(
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        Union[str, list[str]]
+        Union[Output, list[Output]]
             The text generated by the model.
 
         """
@@ -271,24 +300,26 @@ async def generate(
         )
 
         if len(messages) == 1:
-            return messages[0].content
+            return Output(content=messages[0].content)
         else:
-            return [message.content for message in messages]
+            return [Output(content=message.content) for message in messages]
 
     async def generate_batch(
         self,
         model_input,
-        output_type = None,
+        output_type,
+        tools,
         **inference_kwargs,
     ):
         raise NotImplementedError("VLLM does not support batch inference.")
 
     async def generate_stream( # type: ignore
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **inference_kwargs: Any,
-    ) -> AsyncIterator[str]:
+    ) -> AsyncIterator[StreamingOutput]:
         """Stream text using vLLM.
 
         Parameters
@@ -299,13 +330,16 @@ async def generate_stream( # type: ignore
             The desired format of the response generated by the model. All
             output types available in Outlines are supported provided your
             server uses a structured generation backend that supports them.
+        tools
+            The tools to use for the generation.
         inference_kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        AsyncIterator[str]
+        AsyncIterator[StreamingOutput]
             An async iterator that yields the text generated by the model.
+
         """
         client_args = self._build_client_args(
             model_input, output_type, **inference_kwargs,
@@ -318,12 +352,12 @@ async def generate_stream( # type: ignore
 
         async for chunk in stream: # pragma: no cover
             if chunk.choices and chunk.choices[0].delta.content is not None:
-                yield chunk.choices[0].delta.content
+                yield StreamingOutput(content=chunk.choices[0].delta.content)
 
     def _build_client_args(
         self,
         model_input: Union[Chat, str, list],
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
         **inference_kwargs: Any,
     ) -> dict:
         """Build the arguments to pass to the OpenAI client."""

tests/models/test_vllm.py

Lines changed: 25 additions & 24 deletions
@@ -11,6 +11,7 @@
 
 from outlines.inputs import Chat, Image
 from outlines.models.vllm import VLLM, AsyncVLLM, from_vllm
+from outlines.outputs import Output, StreamingOutput
 from outlines.types.dsl import CFG, Regex, JsonSchema
 from tests.test_utils.mock_openai_client import MockOpenAIClient, MockAsyncOpenAIClient
 
@@ -225,7 +226,7 @@ def test_vllm_init():
 
 def test_vllm_sync_simple_call(sync_model):
     result = sync_model("Respond with a single word.",)
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 def test_vllm_sync_streaming(sync_model_no_model_name):
@@ -234,7 +235,7 @@ def test_vllm_sync_streaming(sync_model_no_model_name):
         model=vllm_model_name,
     )
     assert isinstance(result, Generator)
-    assert isinstance(next(result), str)
+    assert isinstance(next(result), StreamingOutput)
 
 
 def test_vllm_sync_batch(sync_model):
@@ -246,7 +247,7 @@ def test_vllm_sync_batch(sync_model):
 
 def test_vllm_sync_vision(sync_model):
     result = sync_model(["hello", image_input], max_tokens=10)
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 def test_vllm_sync_vision_chat(sync_model):
@@ -261,40 +262,40 @@ def test_vllm_sync_vision_chat(sync_model):
         ]),
         max_tokens=10,
     )
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 def test_vllm_sync_multiple_samples(sync_model):
     result = sync_model("Respond with a single word.", n=2)
     assert isinstance(result, list)
     assert len(result) == 2
-    assert isinstance(result[0], str)
-    assert isinstance(result[1], str)
+    assert isinstance(result[0], Output)
+    assert isinstance(result[1], Output)
 
 
 def test_vllm_sync_json(sync_model):
     json_string = '{"type": "object", "properties": {"bar": {"type": "string"}}}'
     result = sync_model("foo?", JsonSchema(json_string), max_tokens=10)
-    assert isinstance(result, str)
-    assert "bar" in result
+    assert isinstance(result, Output)
+    assert "bar" in result.content
 
 
 def test_vllm_sync_regex(sync_model):
     result = sync_model("foo?", Regex(r"[0-9]{3}"), max_tokens=10)
-    assert isinstance(result, str)
-    assert re.match(r"[0-9]{3}", result)
+    assert isinstance(result, Output)
+    assert re.match(r"[0-9]{3}", result.content)
 
 
 def test_vllm_sync_cfg(sync_model):
     result = sync_model("foo?", CFG(YES_NO_GRAMMAR), max_tokens=10)
-    assert isinstance(result, str)
-    assert result in ["yes", "no"]
+    assert isinstance(result, Output)
+    assert result.content in ["yes", "no"]
 
 
 @pytest.mark.asyncio
 async def test_vllm_async_simple_call(async_model):
     result = await async_model("Respond with a single word.",)
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 @pytest.mark.asyncio
@@ -305,7 +306,7 @@ async def test_vllm_async_streaming(async_model_no_model_name):
     )
     assert isinstance(result, AsyncGenerator)
     async for chunk in result:
-        assert isinstance(chunk, str)
+        assert isinstance(chunk, StreamingOutput)
         break # Just check the first chunk
 
 
@@ -320,7 +321,7 @@ async def test_vllm_async_batch(async_model):
 @pytest.mark.asyncio
 async def test_vllm_async_vision(async_model):
     result = await async_model(["hello", image_input], max_tokens=10)
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 @pytest.mark.asyncio
@@ -336,35 +337,35 @@ async def test_vllm_async_vision_chat(async_model):
         ]),
         max_tokens=10,
     )
-    assert isinstance(result, str)
+    assert isinstance(result, Output)
 
 
 @pytest.mark.asyncio
 async def test_vllm_async_multiple_samples(async_model):
     result = await async_model("Respond with a single word.", n=2)
     assert isinstance(result, list)
     assert len(result) == 2
-    assert isinstance(result[0], str)
-    assert isinstance(result[1], str)
+    assert isinstance(result[0], Output)
+    assert isinstance(result[1], Output)
 
 
 @pytest.mark.asyncio
 async def test_vllm_async_json(async_model):
     json_string = '{"type": "object", "properties": {"bar": {"type": "string"}}}'
     result = await async_model("foo?", JsonSchema(json_string), max_tokens=10)
-    assert isinstance(result, str)
-    assert "bar" in result
+    assert isinstance(result, Output)
+    assert "bar" in result.content
 
 
 @pytest.mark.asyncio
 async def test_vllm_async_regex(async_model):
     result = await async_model("foo?", Regex(r"[0-9]{3}"), max_tokens=10)
-    assert isinstance(result, str)
-    assert re.match(r"[0-9]{3}", result)
+    assert isinstance(result, Output)
+    assert re.match(r"[0-9]{3}", result.content)
 
 
 @pytest.mark.asyncio
 async def test_vllm_async_cfg(async_model):
     result = await async_model("foo?", CFG(YES_NO_GRAMMAR), max_tokens=10)
-    assert isinstance(result, str)
-    assert result in ["yes", "no"]
+    assert isinstance(result, Output)
+    assert result.content in ["yes", "no"]
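The asynchronous client follows the same shape: streaming yields StreamingOutput chunks whose text is on .content. Below is a rough sketch, not part of this commit, that drives the generate_stream method modified above directly, passing output_type=None and tools=None positionally as its new signature requires. The server URL and model name are placeholders, and it assumes from_vllm returns an AsyncVLLM when given an openai.AsyncOpenAI client, as the mock-client tests suggest.

import asyncio

import openai

from outlines.models.vllm import from_vllm
from outlines.outputs import StreamingOutput


async def main():
    # Hypothetical setup mirroring the synchronous sketch above.
    model = from_vllm(
        openai.AsyncOpenAI(base_url="http://localhost:8000/v1"),
        "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model name
    )
    # output_type=None (plain text), tools=None (tool calling is unsupported).
    stream = model.generate_stream("Tell me a joke.", None, None, max_tokens=16)
    async for chunk in stream:
        assert isinstance(chunk, StreamingOutput)
        print(chunk.content, end="", flush=True)


asyncio.run(main())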
