Skip to content

Commit 3e7f9a6

Browse files
committed
Implement tools and outputs for the Transformers model
1 parent 9955e1e commit 3e7f9a6

File tree

5 files changed

+127
-45
lines changed

5 files changed

+127
-45
lines changed

outlines/models/transformers.py

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
from outlines.inputs import Audio, Chat, Image, Video
1010
from outlines.models.base import Model, ModelTypeAdapter
1111
from outlines.models.tokenizer import Tokenizer
12+
from outlines.outputs import Output
1213
from outlines.processors import OutlinesLogitsProcessor
14+
from outlines.tools import ToolDef
1315

1416
if TYPE_CHECKING:
1517
import torch
@@ -173,7 +175,7 @@ def format_chat_input(self, model_input: Chat) -> str:
173175

174176
def format_output_type(
175177
self,
176-
output_type: Optional[OutlinesLogitsProcessor] = None,
178+
output_type: Optional[OutlinesLogitsProcessor],
177179
) -> Optional["LogitsProcessorList"]:
178180
"""Generate the logits processor argument to pass to the model.
179181
@@ -194,6 +196,13 @@ def format_output_type(
194196
return LogitsProcessorList([output_type])
195197
return None
196198

199+
def format_tools(self, tools):
200+
"""Not available for Transformers."""
201+
if tools:
202+
raise NotImplementedError(
203+
"Transformers does not support tools."
204+
)
205+
197206

198207
class Transformers(Model):
199208
"""Thin wrapper around a `transformers` model and a `transformers`
@@ -295,9 +304,10 @@ def _prepare_model_inputs(
295304
def generate(
296305
self,
297306
model_input: Union[str, dict, Chat],
298-
output_type: Optional[OutlinesLogitsProcessor] = None,
307+
output_type: Optional[OutlinesLogitsProcessor],
308+
tools: Optional[List[ToolDef]],
299309
**inference_kwargs: Any,
300-
) -> Union[str, List[str]]:
310+
) -> Output | List[Output]:
301311
"""Generate text using `transformers`.
302312
303313
Parameters
@@ -310,16 +320,19 @@ def generate(
310320
output_type
311321
The logits processor the model will use to constrain the format of
312322
the generated text.
323+
tools
324+
The tools to use for the generation.
313325
inference_kwargs
314326
Additional keyword arguments to pass to the `generate` method
315327
of the `transformers` model.
316328
317329
Returns
318330
-------
319-
Union[str, List[str]]
331+
Output | List[Output]
320332
The text generated by the model.
321333
322334
"""
335+
self.type_adapter.format_tools(tools)
323336
prompts, inputs = self._prepare_model_inputs(model_input, False)
324337
logits_processor = self.type_adapter.format_output_type(output_type)
325338

@@ -336,15 +349,39 @@ def generate(
336349
if num_samples == 1 and len(generated_ids.shape) == 2:
337350
generated_ids = generated_ids.squeeze(0)
338351

339-
return self._decode_generation(generated_ids)
352+
generated_text = self._decode_generation(generated_ids)
353+
354+
if isinstance(generated_text, list):
355+
return [Output(content=text) for text in generated_text]
356+
return Output(content=generated_text)
340357

341358
def generate_batch(
342359
self,
343360
model_input: List[Union[str, dict, Chat]],
344-
output_type: Optional[OutlinesLogitsProcessor] = None,
361+
output_type: Optional[OutlinesLogitsProcessor],
362+
tools: Optional[List[ToolDef]],
345363
**inference_kwargs: Any,
346-
) -> List[Union[str, List[str]]]:
347-
""""""
364+
) -> List[Output] | List[List[Output]]:
365+
"""Generate a batch of completions using `transformers`.
366+
367+
Parameters
368+
----------
369+
model_input
370+
The list of prompts based on which the model will generate a response.
371+
output_type
372+
The logits processor the model will use to constrain the format of the generated text.
373+
tools
374+
The tools to use for the generation.
375+
**inference_kwargs
376+
Additional keyword arguments to pass to the `generate` method of the `transformers` model.
377+
378+
Returns
379+
-------
380+
List[Output] | List[List[Output]]
381+
The list of text generated by the model.
382+
383+
"""
384+
self.type_adapter.format_tools(tools)
348385
prompts, inputs = self._prepare_model_inputs(model_input, True) # type: ignore
349386
logits_processor = self.type_adapter.format_output_type(output_type)
350387

@@ -357,7 +394,17 @@ def generate_batch(
357394
if num_samples > 1:
358395
generated_ids = generated_ids.view(len(model_input), num_samples, -1)
359396

360-
return self._decode_generation(generated_ids)
397+
generated_text = self._decode_generation(generated_ids)
398+
399+
return [ # type: ignore
400+
[
401+
Output(content=text)
402+
for text in batch
403+
]
404+
if isinstance(batch, list)
405+
else Output(content=batch)
406+
for batch in generated_text
407+
]
361408

362409
def generate_stream(self, model_input, output_type, **inference_kwargs):
363410
"""Not available for `transformers` models.
@@ -369,7 +416,7 @@ def generate_stream(self, model_input, output_type, **inference_kwargs):
369416
"Streaming is not implemented for Transformers models."
370417
)
371418

372-
def _generate_output_seq(self, prompts, inputs, **inference_kwargs):
419+
def _generate_output_seq(self, prompts, inputs, **inference_kwargs): # type: ignore
373420
input_ids = inputs["input_ids"]
374421

375422
output_ids = self.model.generate(
@@ -472,7 +519,7 @@ def format_chat_input(self, model_input: Chat) -> dict:
472519
"content": message["content"][0],
473520
})
474521
else:
475-
messages_without_images.append(message)
522+
messages_without_images.append(message) # type: ignore
476523
formatted_prompt = self.tokenizer.apply_chat_template(
477524
messages_without_images,
478525
tokenize=False
@@ -513,7 +560,7 @@ def format_list_input(self, model_input: list) -> dict:
513560

514561
def format_output_type(
515562
self,
516-
output_type: Optional[OutlinesLogitsProcessor] = None,
563+
output_type: Optional[OutlinesLogitsProcessor],
517564
) -> Optional["LogitsProcessorList"]:
518565
"""Generate the logits processor argument to pass to the model.
519566
@@ -534,6 +581,13 @@ def format_output_type(
534581
return LogitsProcessorList([output_type])
535582
return None
536583

584+
def format_tools(self, tools):
585+
"""Not available for TransformersMultiModal."""
586+
if tools:
587+
raise NotImplementedError(
588+
"TransformersMultiModal does not support tools."
589+
)
590+
537591

538592
class TransformersMultiModal(Transformers):
539593
"""Thin wrapper around a `transformers` model and a `transformers`

tests/models/test_transformers.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TransformerTokenizer,
1313
TransformersTypeAdapter,
1414
)
15+
from outlines.outputs import Output, StreamingOutput
1516
from outlines.types import Regex
1617

1718

@@ -80,16 +81,16 @@ def model_bart():
8081

8182

8283
def test_transformers_simple(model):
83-
result = model.generate("Respond with one word. Not more.", None)
84-
assert isinstance(result, str)
84+
result = model("Respond with one word. Not more.", None)
85+
assert isinstance(result, Output)
8586

8687

8788
def test_transformers_call(model, model_bart):
8889
result = model("Respond with one word. Not more.")
89-
assert isinstance(result, str)
90+
assert isinstance(result, Output)
9091

9192
result = model_bart("Respond with one word. Not more.")
92-
assert isinstance(result, str)
93+
assert isinstance(result, Output)
9394

9495

9596
def test_transformers_chat(model):
@@ -99,12 +100,12 @@ def test_transformers_chat(model):
99100
{"role": "user", "content": "What is the capital of France?"},
100101
])
101102
)
102-
assert isinstance(result, str)
103+
assert isinstance(result, Output)
103104

104105

105106
def test_transformers_inference_kwargs(model):
106107
result = model("Respond with one word. Not more.", max_new_tokens=100)
107-
assert isinstance(result, str)
108+
assert isinstance(result, Output)
108109

109110

110111
def test_transformers_invalid_inference_kwargs(model):
@@ -114,16 +115,16 @@ def test_transformers_invalid_inference_kwargs(model):
114115

115116
def test_transformers_regex(model):
116117
result = model("Give a number between 0 and 9.", Regex(r"[0-9]"))
117-
assert isinstance(result, str)
118-
assert re.match(r"[0-9]", result)
118+
assert isinstance(result, Output)
119+
assert re.match(r"[0-9]", result.content)
119120

120121

121122
def test_transformers_json(model):
122123
class Character(BaseModel):
123124
name: str
124125

125126
result = model("Create a character with a name.", Character)
126-
assert "name" in result
127+
assert "name" in result.content
127128

128129

129130
def test_transformers_choice(model):
@@ -132,12 +133,12 @@ class Foo(Enum):
132133
dog = "dog"
133134

134135
result = model("Cat or dog?", Foo)
135-
assert result in ["cat", "dog"]
136+
assert result.content in ["cat", "dog"]
136137

137138

138139
def test_transformers_multiple_samples(model):
139140
result = model("Respond with one word. Not more.")
140-
assert isinstance(result, str)
141+
assert isinstance(result, Output)
141142
result = model(
142143
"Respond with one word. Not more.", num_return_sequences=2, do_sample=True
143144
)
@@ -187,8 +188,8 @@ class Foo(Enum):
187188
result = model("Cat or dog?", Foo, num_return_sequences=2, do_sample=True)
188189
assert isinstance(result, list)
189190
assert len(result) == 2
190-
assert result[0] in ["cat", "dog"]
191-
assert result[1] in ["cat", "dog"]
191+
assert result[0].content in ["cat", "dog"]
192+
assert result[1].content in ["cat", "dog"]
192193

193194

194195
def test_transformers_batch_constrained(model):
@@ -202,8 +203,8 @@ class Foo(Enum):
202203
)
203204
assert isinstance(result, list)
204205
assert len(result) == 2
205-
assert result[0] in ["cat", "dog"]
206-
assert result[1] in ["cat", "dog"]
206+
assert result[0].content in ["cat", "dog"]
207+
assert result[1].content in ["cat", "dog"]
207208

208209
result = model.batch(
209210
["Cat or dog?", "Cat or dog?"],
@@ -216,8 +217,8 @@ class Foo(Enum):
216217
for item in result:
217218
assert isinstance(item, list)
218219
assert len(item) == 2
219-
assert item[0] in ["cat", "dog"]
220-
assert item[1] in ["cat", "dog"]
220+
assert item[0].content in ["cat", "dog"]
221+
assert item[1].content in ["cat", "dog"]
221222

222223

223224
def test_transformers_streaming(model):

tests/models/test_transformers_multimodal.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
TransformerTokenizer,
2020
TransformersMultiModalTypeAdapter,
2121
)
22+
from outlines.outputs import Output, StreamingOutput
2223
from outlines.types import Regex
2324

2425
TEST_MODEL = "trl-internal-testing/tiny-LlavaForConditionalGeneration"
@@ -61,20 +62,20 @@ def test_transformers_multimodal_instantiate_simple():
6162

6263

6364
def test_transformers_multimodal_simple(model, image):
64-
result = model.generate(
65+
result = model(
6566
["<image>Describe this image in one sentence:", Image(image)],
6667
None,
6768
max_new_tokens=2,
6869
)
69-
assert isinstance(result, str)
70+
assert isinstance(result, Output)
7071

7172

7273
def test_transformers_multimodal_call(model, image):
7374
result = model(
7475
["<image>Describe this image in one sentence:", Image(image)],
7576
max_new_tokens=2,
7677
)
77-
assert isinstance(result, str)
78+
assert isinstance(result, Output)
7879

7980

8081
def test_transformers_multimodal_wrong_number_image(model, image):
@@ -90,7 +91,7 @@ def test_transformers_multimodal_wrong_number_image(model, image):
9091

9192
def test_transformers_multimodal_wrong_input_type(model):
9293
with pytest.raises(TypeError):
93-
model.generate("invalid input", None)
94+
model("invalid input", None)
9495

9596

9697
def test_transformers_multimodal_chat(model, image):
@@ -107,15 +108,15 @@ def test_transformers_multimodal_chat(model, image):
107108
]),
108109
max_new_tokens=2,
109110
)
110-
assert isinstance(result, str)
111+
assert isinstance(result, Output)
111112

112113

113114
def test_transformers_inference_kwargs(model, image):
114115
result = model(
115116
["<image>Describe this image in one sentence:", Image(image)],
116117
max_new_tokens=2,
117118
)
118-
assert isinstance(result, str)
119+
assert isinstance(result, Output)
119120

120121

121122
def test_transformers_invalid_inference_kwargs(model, image):
@@ -138,7 +139,7 @@ def test_transformers_several_image(model, image):
138139
],
139140
max_new_tokens=2,
140141
)
141-
assert isinstance(result, str)
142+
assert isinstance(result, Output)
142143

143144

144145
def test_transformers_multimodal_json(model, image):
@@ -150,7 +151,8 @@ class Foo(BaseModel):
150151
Foo,
151152
max_new_tokens=10,
152153
)
153-
assert "name" in result
154+
assert isinstance(result, Output)
155+
assert "name" in result.content
154156

155157

156158
def test_transformers_multimodal_regex(model, image):
@@ -159,8 +161,8 @@ def test_transformers_multimodal_regex(model, image):
159161
Regex(r"[0-9]")
160162
)
161163

162-
assert isinstance(result, str)
163-
assert re.match(r"[0-9]", result)
164+
assert isinstance(result, Output)
165+
assert re.match(r"[0-9]", result.content)
164166

165167

166168
def test_transformers_multimodal_choice(model, image):
@@ -173,8 +175,8 @@ class Foo(Enum):
173175
Foo,
174176
)
175177

176-
assert isinstance(result, str)
177-
assert result in ["white", "blue"]
178+
assert isinstance(result, Output)
179+
assert result.content in ["white", "blue"]
178180

179181

180182
def test_transformers_multimodal_multiple_samples(model, image):
@@ -245,12 +247,12 @@ def test_transformers_multimodal_batch(model, image):
245247

246248
def test_transformers_multimodal_deprecated_input_type(model, image):
247249
with pytest.warns(DeprecationWarning):
248-
result = model.generate(
250+
result = model(
249251
{
250252
"text": "<image>Describe this image in one sentence:",
251253
"image": image,
252254
},
253255
None,
254256
max_new_tokens=2,
255257
)
256-
assert isinstance(result, str)
258+
assert isinstance(result, Output)

0 commit comments

Comments
 (0)