Skip to content

Commit a2f0fe6

Browse files
committed
Draft
1 parent 778cd10 commit a2f0fe6

File tree

14 files changed

+1520
-229
lines changed

14 files changed

+1520
-229
lines changed

outlines/generator.py

Lines changed: 95 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
get_regex_logits_processor,
2222
)
2323
from outlines.backends.base import LogitsProcessorType
24+
from outlines.outputs import Output, StreamingOutput
25+
from outlines.tools import get_formatted_tools, ToolsInput
2426
from outlines.types import CFG, JsonSchema
2527
from outlines.types.dsl import python_types_to_terms, to_regex
2628

@@ -35,20 +37,30 @@ class BlackBoxGenerator:
3537
"""
3638
output_type: Optional[Any]
3739

38-
def __init__(self, model: BlackBoxModel, output_type: Optional[Any]):
40+
def __init__(
41+
self,
42+
model: BlackBoxModel,
43+
output_type: Optional[Any],
44+
*,
45+
tools: Optional[ToolsInput] = None,
46+
):
3947
"""
4048
Parameters
4149
----------
4250
model
4351
An instance of an Outlines model.
4452
output_type
4553
The output type that will be used to constrain the generation.
54+
tools
55+
A list of tools to use for the generator. Can contain an MCPServer,
56+
a list of ToolDef, Callable, or BaseModel instances.
4657
4758
"""
4859
self.model = model
4960
self.output_type = output_type
61+
self.tools = get_formatted_tools(tools)
5062

51-
def __call__(self, prompt: Any, **inference_kwargs) -> Any:
63+
def __call__(self, prompt: Any, **inference_kwargs) -> Output | List[Output]:
5264
"""Generate a response from the model.
5365
5466
Parameters
@@ -60,15 +72,17 @@ def __call__(self, prompt: Any, **inference_kwargs) -> Any:
6072
6173
Returns
6274
-------
63-
Any
64-
The response generated by the model.
75+
Output | List[Output]
76+
The output generated by the model.
6577
6678
"""
6779
return self.model.generate(
68-
prompt, self.output_type, **inference_kwargs
80+
prompt, self.output_type, tools=self.tools, **inference_kwargs
6981
)
7082

71-
def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
83+
def batch(
84+
self, prompts: List[Any], **inference_kwargs
85+
) -> List[Output] | List[List[Output]]:
7286
"""Generate a batch of responses from the model.
7387
7488
Parameters
@@ -80,15 +94,17 @@ def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
8094
8195
Returns
8296
-------
83-
List[Any]
84-
The list of responses generated by the model.
97+
List[Output] | List[List[Output]]
98+
The list of outputs generated by the model.
8599
86100
"""
87101
return self.model.generate_batch(
88-
prompts, self.output_type, **inference_kwargs
102+
prompts, self.output_type, tools=self.tools, **inference_kwargs
89103
)
90104

91-
def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
105+
def stream(
106+
self, prompt: Any, **inference_kwargs
107+
) -> Iterator[StreamingOutput]:
92108
"""Generate a stream of responses from the model.
93109
94110
Parameters
@@ -100,12 +116,12 @@ def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
100116
101117
Returns
102118
-------
103-
Any
104-
The response generated by the model.
119+
Iterator[StreamingOutput]
120+
A stream of StreamingOutput generated by the model.
105121
106122
"""
107123
return self.model.generate_stream(
108-
prompt, self.output_type, **inference_kwargs
124+
prompt, self.output_type, tools=self.tools, **inference_kwargs
109125
)
110126

111127

@@ -119,20 +135,32 @@ class AsyncBlackBoxGenerator:
119135
"""
120136
output_type: Optional[Any]
121137

122-
def __init__(self, model: AsyncBlackBoxModel, output_type: Optional[Any]):
138+
def __init__(
139+
self,
140+
model: AsyncBlackBoxModel,
141+
output_type: Optional[Any],
142+
*,
143+
tools: Optional[ToolsInput] = None,
144+
):
123145
"""
124146
Parameters
125147
----------
126148
model
127149
An instance of an Outlines model.
128150
output_type
129151
The output type that will be used to constrain the generation.
152+
tools
153+
A list of tools to use for the generator. Can contain an MCPServer,
154+
a list of ToolDef, Callable, or BaseModel instances.
130155
131156
"""
132157
self.model = model
133158
self.output_type = output_type
159+
self.tools = get_formatted_tools(tools)
134160

135-
async def __call__(self, prompt: Any, **inference_kwargs) -> Any:
161+
async def __call__(
162+
self, prompt: Any, **inference_kwargs
163+
) -> Output | List[Output]:
136164
"""Generate a response from the model.
137165
138166
Parameters
@@ -144,15 +172,17 @@ async def __call__(self, prompt: Any, **inference_kwargs) -> Any:
144172
145173
Returns
146174
-------
147-
Any
148-
The response generated by the model.
175+
Output | List[Output]
176+
The output generated by the model.
149177
150178
"""
151179
return await self.model.generate(
152-
prompt, self.output_type, **inference_kwargs
180+
prompt, self.output_type, tools=self.tools, **inference_kwargs
153181
)
154182

155-
async def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
183+
async def batch(
184+
self, prompts: List[Any], **inference_kwargs
185+
) -> List[Output] | List[List[Output]]:
156186
"""Generate a batch of responses from the model.
157187
158188
Parameters
@@ -164,15 +194,17 @@ async def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
164194
165195
Returns
166196
-------
167-
List[Any]
168-
The list of responses generated by the model.
197+
List[Output] | List[List[Output]]
198+
The list of outputs generated by the model.
169199
170200
"""
171201
return await self.model.generate_batch(
172-
prompts, self.output_type, **inference_kwargs
202+
prompts, self.output_type, tools=self.tools, **inference_kwargs
173203
)
174204

175-
async def stream(self, prompt: Any, **inference_kwargs) -> AsyncIterator[Any]:
205+
async def stream(
206+
self, prompt: Any, **inference_kwargs
207+
) -> AsyncIterator[StreamingOutput]:
176208
"""Generate a stream of responses from the model.
177209
178210
Parameters
@@ -184,12 +216,13 @@ async def stream(self, prompt: Any, **inference_kwargs) -> AsyncIterator[Any]:
184216
185217
Returns
186218
-------
187-
Any
188-
The response generated by the model.
219+
AsyncIterator[StreamingOutput]
220+
A coroutine that will produce an async iterator of StreamingOutput
221+
produced by the model.
189222
190223
"""
191224
async for chunk in self.model.generate_stream( # pragma: no cover
192-
prompt, self.output_type, **inference_kwargs
225+
prompt, self.output_type, tools=self.tools, **inference_kwargs
193226
):
194227
yield chunk
195228

@@ -218,6 +251,8 @@ def __init__(
218251
model: SteerableModel,
219252
output_type: Optional[Any],
220253
backend_name: Optional[str] = None,
254+
*,
255+
tools: Optional[ToolsInput] = None,
221256
):
222257
"""
223258
Parameters
@@ -228,9 +263,13 @@ def __init__(
228263
The output type expressed as a Python type
229264
backend_name
230265
The name of the backend to use to create the logits processor.
266+
tools
267+
A list of tools to use for the generator. Can contain an MCPServer,
268+
a list of ToolDef, Callable, or BaseModel instances.
231269
232270
"""
233271
self.model = model
272+
self.tools = get_formatted_tools(tools)
234273
if output_type is None:
235274
self.logits_processor = None
236275
else:
@@ -258,7 +297,11 @@ def __init__(
258297

259298
@classmethod
260299
def from_processor(
261-
cls, model: SteerableModel, processor: LogitsProcessorType
300+
cls,
301+
model: SteerableModel,
302+
processor: LogitsProcessorType,
303+
*,
304+
tools: Optional[ToolsInput] = None,
262305
):
263306
"""Create a generator from a logits processor.
264307
@@ -270,13 +313,12 @@ def from_processor(
270313
An instance of a logits processor.
271314
272315
"""
273-
instance = cls.__new__(cls)
274-
instance.model = model
316+
instance = cls(model, None, tools=tools)
275317
instance.logits_processor = processor
276318

277319
return instance
278320

279-
def __call__(self, prompt: Any, **inference_kwargs) -> Any:
321+
def __call__(self, prompt: Any, **inference_kwargs) -> Output | List[Output]:
280322
"""Generate a response from the model.
281323
282324
Parameters
@@ -288,17 +330,19 @@ def __call__(self, prompt: Any, **inference_kwargs) -> Any:
288330
289331
Returns
290332
-------
291-
Any
292-
The response generated by the model.
333+
Output | List[Output]
334+
The output generated by the model.
293335
294336
"""
295337
if self.logits_processor is not None:
296338
self.logits_processor.reset()
297339
return self.model.generate(
298-
prompt, self.logits_processor, **inference_kwargs
340+
prompt, self.logits_processor, tools=self.tools, **inference_kwargs
299341
)
300342

301-
def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
343+
def batch(
344+
self, prompts: List[Any], **inference_kwargs
345+
) -> List[Output] | List[List[Output]]:
302346
"""Generate a batch of responses from the model.
303347
304348
Parameters
@@ -310,17 +354,19 @@ def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
310354
311355
Returns
312356
-------
313-
List[Any]
314-
The list of responses generated by the model.
357+
List[Output] | List[List[Output]]
358+
The list of outputs generated by the model.
315359
316360
"""
317361
if self.logits_processor is not None:
318362
self.logits_processor.reset()
319363
return self.model.generate_batch(
320-
prompts, self.logits_processor, **inference_kwargs
364+
prompts, self.logits_processor, tools=self.tools, **inference_kwargs
321365
)
322366

323-
def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
367+
def stream(
368+
self, prompt: Any, **inference_kwargs
369+
) -> Iterator[StreamingOutput]:
324370
"""Generate a stream of responses from the model.
325371
326372
Parameters
@@ -332,14 +378,14 @@ def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
332378
333379
Returns
334380
-------
335-
Any
336-
The response generated by the model.
381+
Iterator[StreamingOutput]
382+
A stream of StreamingOutput generated by the model.
337383
338384
"""
339385
if self.logits_processor is not None:
340386
self.logits_processor.reset()
341387
return self.model.generate_stream(
342-
prompt, self.logits_processor, **inference_kwargs
388+
prompt, self.logits_processor, tools=self.tools, **inference_kwargs
343389
)
344390

345391

@@ -348,6 +394,7 @@ def Generator(
348394
output_type: Optional[Any] = None,
349395
backend: Optional[str] = None,
350396
*,
397+
tools: Optional[ToolsInput] = None,
351398
processor: Optional[LogitsProcessorType] = None,
352399
) -> Union[SteerableGenerator, BlackBoxGenerator, AsyncBlackBoxGenerator]:
353400
"""Create a generator for the given model and output parameters.
@@ -367,6 +414,9 @@ def Generator(
367414
The name of the backend to use to create the logits processor. Only
368415
used for steerable models if there is an output type and `processor` is
369416
not provided.
417+
tools
418+
A list of tools to use for the generator. Can contain an MCPServer,
419+
a list of ToolDef, Callable, or BaseModel instances.
370420
processor
371421
An instance of a logits processor.
372422
@@ -387,18 +437,18 @@ def Generator(
387437

388438
if isinstance(model, SteerableModel): # type: ignore
389439
if processor is not None:
390-
return SteerableGenerator.from_processor(model, processor) # type: ignore
440+
return SteerableGenerator.from_processor(model, processor, tools=tools) # type: ignore
391441
else:
392-
return SteerableGenerator(model, output_type, backend) # type: ignore
442+
return SteerableGenerator(model, output_type, backend, tools=tools) # type: ignore
393443
else:
394444
if processor is not None:
395445
raise NotImplementedError(
396446
"This model does not support logits processors"
397447
)
398448
if isinstance(model, AsyncBlackBoxModel): # type: ignore
399-
return AsyncBlackBoxGenerator(model, output_type) # type: ignore
449+
return AsyncBlackBoxGenerator(model, output_type, tools=tools) # type: ignore
400450
elif isinstance(model, BlackBoxModel): # type: ignore
401-
return BlackBoxGenerator(model, output_type) # type: ignore
451+
return BlackBoxGenerator(model, output_type, tools=tools) # type: ignore
402452
else:
403453
raise ValueError(
404454
"The model argument must be an instance of "

0 commit comments

Comments
 (0)