Skip to content

Commit a6b231f

Browse files
committed
Implement tools and outputs for the VLLMOffline model
1 parent b3f5032 commit a6b231f

File tree

3 files changed

+61
-24
lines changed

3 files changed

+61
-24
lines changed

outlines/models/vllm_offline.py

Lines changed: 37 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77
from outlines.inputs import Chat
88
from outlines.models.base import Model, ModelTypeAdapter
99
from outlines.models.openai import OpenAITypeAdapter
10+
from outlines.outputs import Output
11+
from outlines.tools import ToolDef
1012
from outlines.types.dsl import CFG, JsonSchema, python_types_to_terms, to_regex
1113

1214
if TYPE_CHECKING:
@@ -56,7 +58,7 @@ def format_input_chat(self, model_input: Chat) -> list:
5658
)
5759
return OpenAITypeAdapter().format_input(model_input)
5860

59-
def format_output_type(self, output_type: Optional[Any] = None) -> dict:
61+
def format_output_type(self, output_type: Optional[Any]) -> dict:
6062
"""Generate the structured output argument to pass to the model.
6163
6264
For vLLM, the structured output definition is set in the
@@ -90,6 +92,13 @@ def format_output_type(self, output_type: Optional[Any] = None) -> dict:
9092
else:
9193
return {"regex": to_regex(term)}
9294

95+
def format_tools(self, tools):
96+
"""Not available for VLLM offline."""
97+
if tools:
98+
raise NotImplementedError(
99+
"Tools are not available for VLLM offline."
100+
)
101+
93102

94103
class VLLMOffline(Model):
95104
"""Thin wrapper around a `vllm.LLM` model.
@@ -114,7 +123,7 @@ def __init__(self, model: "LLM"):
114123
def _build_generation_args(
115124
self,
116125
inference_kwargs: dict,
117-
output_type: Optional[Any] = None,
126+
output_type: Optional[Any],
118127
) -> "SamplingParams":
119128
"""Create the `SamplingParams` object to pass to the `generate` method
120129
of the `vllm.LLM` model."""
@@ -134,9 +143,10 @@ def _build_generation_args(
134143
def generate(
135144
self,
136145
model_input: Chat | str,
137-
output_type: Optional[Any] = None,
146+
output_type: Optional[Any],
147+
tools: Optional[List[ToolDef]],
138148
**inference_kwargs: Any,
139-
) -> Union[str, List[str]]:
149+
) -> Union[Output, List[Output]]:
140150
"""Generate text using vLLM offline.
141151
142152
Parameters
@@ -146,16 +156,19 @@ def generate(
146156
output_type
147157
The logits processor the model will use to constrain the format of
148158
the generated text.
159+
tools
160+
The tools to use for the generation.
149161
inference_kwargs
150162
Additional keyword arguments to pass to the `generate` method
151163
in the `vllm.LLM` model.
152164
153165
Returns
154166
-------
155-
Union[str, List[str]]
167+
Union[Output, List[Output]]
156168
The text generated by the model.
157169
158170
"""
171+
self.type_adapter.format_tools(tools)
159172
sampling_params = self._build_generation_args(
160173
inference_kwargs,
161174
output_type,
@@ -168,24 +181,25 @@ def generate(
168181
**inference_kwargs,
169182
)
170183
else:
171-
results = self.model.generate(
184+
results = self.model(
172185
prompts=self.type_adapter.format_input(model_input),
173186
sampling_params=sampling_params,
174187
**inference_kwargs,
175188
)
176189
results = [completion.text for completion in results[0].outputs]
177190

178191
if len(results) == 1:
179-
return results[0]
192+
return Output(content=results[0])
180193
else:
181-
return results
194+
return [Output(content=result) for result in results]
182195

183196
def generate_batch(
184197
self,
185198
model_input: List[Chat | str],
186-
output_type: Optional[Any] = None,
199+
output_type: Optional[Any],
200+
tools: Optional[List[ToolDef]],
187201
**inference_kwargs: Any,
188-
) -> Union[List[str], List[List[str]]]:
202+
) -> Union[List[Output], List[List[Output]]]:
189203
"""Generate a batch of completions using vLLM offline.
190204
191205
Parameters
@@ -196,16 +210,19 @@ def generate_batch(
196210
output_type
197211
The logits processor the model will use to constrain the format of
198212
the generated text.
213+
tools
214+
The tools to use for the generation.
199215
inference_kwargs
200216
Additional keyword arguments to pass to the `generate` method
201217
in the `vllm.LLM` model.
202218
203219
Returns
204220
-------
205-
Union[List[str], List[List[str]]]
221+
Union[List[Output], List[List[Output]]]
206222
The text generated by the model.
207223
208224
"""
225+
self.type_adapter.format_tools(tools)
209226
sampling_params = self._build_generation_args(
210227
inference_kwargs,
211228
output_type,
@@ -216,14 +233,20 @@ def generate_batch(
216233
"Batch generation is not available for the `Chat` input type."
217234
)
218235

219-
results = self.model.generate(
236+
results = self.model(
220237
prompts=[self.type_adapter.format_input(item) for item in model_input],
221238
sampling_params=sampling_params,
222239
**inference_kwargs,
223240
)
224-
return [[sample.text for sample in batch.outputs] for batch in results]
225241

226-
def generate_stream(self, model_input, output_type, **inference_kwargs):
242+
return [ # type: ignore
243+
[Output(content=sample.text) for sample in batch.outputs]
244+
if len(batch.outputs) > 1
245+
else Output(content=batch.outputs[0].text)
246+
for batch in results
247+
]
248+
249+
def generate_stream(self, model_input, output_type, tools, **inference_kwargs):
227250
"""Not available for `vllm.LLM`.
228251
229252
TODO: Implement the streaming functionality ourselves.

tests/models/test_vllm_offline.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
VLLMOfflineTypeAdapter,
2020
from_vllm_offline
2121
)
22+
from outlines.outputs import Output, StreamingOutput
2223
from outlines.types import Regex
2324

2425

@@ -58,13 +59,13 @@ def model(tmp_path_factory):
5859

5960

6061
def test_vllm_simple(model):
61-
result = model.generate("Respond with one word. Not more.", None)
62-
assert isinstance(result, str)
62+
result = model("Respond with one word. Not more.", None)
63+
assert isinstance(result, Output)
6364

6465

6566
def test_vllm_call(model):
6667
result = model("Respond with one word. Not more.")
67-
assert isinstance(result, str)
68+
assert isinstance(result, Output)
6869

6970

7071
def test_vllm_inference_kwargs(model):
@@ -73,8 +74,8 @@ def test_vllm_inference_kwargs(model):
7374
sampling_params=SamplingParams(max_tokens=2),
7475
use_tqdm=True
7576
)
76-
assert isinstance(result, str)
77-
assert len(result) <= 20
77+
assert isinstance(result, Output)
78+
assert len(result.content) <= 20
7879

7980

8081
def test_vllm_chat(model):
@@ -86,7 +87,7 @@ def test_vllm_chat(model):
8687
]),
8788
sampling_params=SamplingParams(max_tokens=2),
8889
)
89-
assert isinstance(result, str)
90+
assert isinstance(result, Output)
9091

9192

9293
def test_vllm_invalid_inference_kwargs(model):
@@ -96,16 +97,16 @@ def test_vllm_invalid_inference_kwargs(model):
9697

9798
def test_vllm_regex(model):
9899
result = model("Give a number between 0 and 9.", Regex(r"[0-9]"))
99-
assert isinstance(result, str)
100-
assert re.match(r"[0-9]", result)
100+
assert isinstance(result, Output)
101+
assert re.match(r"[0-9]", result.content)
101102

102103

103104
def test_vllm_json(model):
104105
class Character(BaseModel):
105106
name: str
106107

107108
result = model("Create a character with a name.", Character)
108-
assert "name" in result
109+
assert "name" in result.content
109110

110111

111112
def test_vllm_choice(model):
@@ -114,7 +115,7 @@ class Foo(Enum):
114115
dog = "dog"
115116

116117
result = model("Cat or dog?", Foo)
117-
assert result in ["cat", "dog"]
118+
assert result.content in ["cat", "dog"]
118119

119120

120121
def test_vllm_multiple_samples(model):

tests/models/test_vllm_offline_type_adapter.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
from outlines.inputs import Chat, Image
88
from outlines.models.vllm_offline import VLLMOfflineTypeAdapter
9+
from outlines.tools import ToolDef
910
from outlines.types import CFG, JsonSchema, Regex
1011

1112

@@ -113,3 +114,15 @@ def test_vllm_offline_type_adapter_output_type(
113114
assert type_adapter.format_output_type(regex_instance) == {
114115
"regex": "([0-9]+)"
115116
}
117+
118+
119+
def test_vllm_offline_type_adapter_tools(type_adapter):
120+
with pytest.raises(
121+
NotImplementedError,
122+
match="Tools are not available for VLLM offline."
123+
):
124+
type_adapter.format_tools(
125+
[ToolDef(name="test", description="test", parameters={})]
126+
)
127+
128+
type_adapter.format_tools(None)

0 commit comments

Comments (0)