22
33import json
44from functools import singledispatchmethod
5- from typing import TYPE_CHECKING , Any , AsyncIterator , Iterator , Optional , Union
5+ from typing import (
6+ TYPE_CHECKING ,
7+ Any ,
8+ AsyncIterator ,
9+ Iterator ,
10+ List ,
11+ Optional ,
12+ Union ,
13+ )
614
715from pydantic import TypeAdapter
816
917from outlines .inputs import Chat , Image
1018from outlines .models .base import AsyncModel , Model , ModelTypeAdapter
19+ from outlines .outputs import Output , StreamingOutput
20+ from outlines .tools import ToolDef
1121from outlines .types import CFG , JsonSchema , Regex
1222from outlines .types .utils import (
1323 is_dataclass ,
@@ -74,7 +84,7 @@ def format_chat_model_input(self, model_input: Chat) -> list:
7484
7585 """
7686 return [
77- self ._create_message (message ["role" ], message ["content" ])
87+ self ._create_message (message ["role" ], message ["content" ]) # type: ignore
7888 for message in model_input .messages
7989 ]
8090
@@ -107,9 +117,7 @@ def _create_message(self, role: str, content: str | list) -> dict:
107117 "and a list of images."
108118 )
109119
110- def format_output_type (
111- self , output_type : Optional [Any ] = None
112- ) -> Optional [str ]:
120+ def format_output_type (self , output_type : Optional [Any ]) -> Optional [str ]:
113121 """Format the output type to pass to the client.
114122
115123 TODO: `int`, `float` and other Python types could be supported via
@@ -159,6 +167,13 @@ def format_output_type(
159167 "Consider using a local model instead."
160168 )
161169
def format_tools(self, tools):
    """Not available for Ollama."""
    # Guard clause: nothing to do when no tools were requested.
    if not tools:
        return
    # Ollama has no tool-calling support in this adapter, so any
    # non-empty tools argument is an error.
    raise NotImplementedError("Tools are not available for Ollama.")
162177
163178class Ollama (Model ):
164179 """Thin wrapper around the `ollama.Client` client.
@@ -184,9 +199,10 @@ def __init__(self, client: "Client", model_name: Optional[str] = None):
184199
185200 def generate (self ,
186201 model_input : Chat | str | list ,
187- output_type : Optional [Any ] = None ,
202+ output_type : Optional [Any ],
203+ tools : Optional [List [ToolDef ]],
188204 ** kwargs : Any ,
189- ) -> str :
205+ ) -> Output :
190206 """Generate text using Ollama.
191207
192208 Parameters
@@ -197,15 +213,19 @@ def generate(self,
197213 The desired format of the response generated by the model. The
198214 output type must be of a type that can be converted to a JSON
199215 schema.
216+ tools
217+ The tools to use for the generation.
200218 **kwargs
201219 Additional keyword arguments to pass to the client.
202220
203221 Returns
204222 -------
205- str
223+ Output
206224 The text generated by the model.
207225
208226 """
227+ self .type_adapter .format_tools (tools )
228+
209229 if "model" not in kwargs and self .model_name is not None :
210230 kwargs ["model" ] = self .model_name
211231
@@ -214,12 +234,14 @@ def generate(self,
214234 format = self .type_adapter .format_output_type (output_type ),
215235 ** kwargs ,
216236 )
217- return response .message .content
237+
238+ return Output (content = response .message .content )
218239
219240 def generate_batch (
220241 self ,
221242 model_input ,
222- output_type = None ,
243+ output_type ,
244+ tools ,
223245 ** kwargs ,
224246 ):
225247 raise NotImplementedError (
@@ -229,9 +251,10 @@ def generate_batch(
229251 def generate_stream (
230252 self ,
231253 model_input : Chat | str | list ,
232- output_type : Optional [Any ] = None ,
254+ output_type : Optional [Any ],
255+ tools : Optional [List [ToolDef ]],
233256 ** kwargs : Any ,
234- ) -> Iterator [str ]:
257+ ) -> Iterator [StreamingOutput ]:
235258 """Stream text using Ollama.
236259
237260 Parameters
@@ -242,15 +265,19 @@ def generate_stream(
242265 The desired format of the response generated by the model. The
243266 output type must be of a type that can be converted to a JSON
244267 schema.
268+ tools
269+ The tools to use for the generation.
245270 **kwargs
246271 Additional keyword arguments to pass to the client.
247272
248273 Returns
249274 -------
250- Iterator[str ]
275+ Iterator[StreamingOutput ]
251276 An iterator that yields the text generated by the model.
252277
253278 """
279+ self .type_adapter .format_tools (tools )
280+
254281 if "model" not in kwargs and self .model_name is not None :
255282 kwargs ["model" ] = self .model_name
256283
@@ -260,8 +287,9 @@ def generate_stream(
260287 stream = True ,
261288 ** kwargs ,
262289 )
290+
263291 for chunk in response :
264- yield chunk .message .content
292+ yield StreamingOutput ( content = chunk .message .content )
265293
266294
267295class AsyncOllama (AsyncModel ):
@@ -290,9 +318,10 @@ def __init__(
290318
291319 async def generate (self ,
292320 model_input : Chat | str | list ,
293- output_type : Optional [Any ] = None ,
321+ output_type : Optional [Any ],
322+ tools : Optional [List [ToolDef ]],
294323 ** kwargs : Any ,
295- ) -> str :
324+ ) -> Output :
296325 """Generate text using Ollama.
297326
298327 Parameters
@@ -303,15 +332,19 @@ async def generate(self,
303332 The desired format of the response generated by the model. The
304333 output type must be of a type that can be converted to a JSON
305334 schema.
335+ tools
336+ The tools to use for the generation.
306337 **kwargs
307338 Additional keyword arguments to pass to the client.
308339
309340 Returns
310341 -------
311- str
342+ Output
312343 The text generated by the model.
313344
314345 """
346+ self .type_adapter .format_tools (tools )
347+
315348 if "model" not in kwargs and self .model_name is not None :
316349 kwargs ["model" ] = self .model_name
317350
@@ -320,12 +353,14 @@ async def generate(self,
320353 format = self .type_adapter .format_output_type (output_type ),
321354 ** kwargs ,
322355 )
323- return response .message .content
356+
357+ return Output (content = response .message .content )
324358
325359 async def generate_batch (
326360 self ,
327361 model_input ,
328- output_type = None ,
362+ output_type ,
363+ tools ,
329364 ** kwargs ,
330365 ):
331366 raise NotImplementedError (
@@ -335,9 +370,10 @@ async def generate_batch(
335370 async def generate_stream ( # type: ignore
336371 self ,
337372 model_input : Chat | str | list ,
338- output_type : Optional [Any ] = None ,
373+ output_type : Optional [Any ],
374+ tools : Optional [List [ToolDef ]],
339375 ** kwargs : Any ,
340- ) -> AsyncIterator [str ]:
376+ ) -> AsyncIterator [StreamingOutput ]:
341377 """Stream text using Ollama.
342378
343379 Parameters
@@ -348,15 +384,19 @@ async def generate_stream( # type: ignore
348384 The desired format of the response generated by the model. The
349385 output type must be of a type that can be converted to a JSON
350386 schema.
387+ tools
388+ The tools to use for the generation.
351389 **kwargs
352390 Additional keyword arguments to pass to the client.
353391
354392 Returns
355393 -------
356- Iterator[str ]
394+ Iterator[StreamingOutput ]
357395 An iterator that yields the text generated by the model.
358396
359397 """
398+ self .type_adapter .format_tools (tools )
399+
360400 if "model" not in kwargs and self .model_name is not None :
361401 kwargs ["model" ] = self .model_name
362402
@@ -366,8 +406,9 @@ async def generate_stream( # type: ignore
366406 stream = True ,
367407 ** kwargs ,
368408 )
409+
369410 async for chunk in stream :
370- yield chunk .message .content
411+ yield StreamingOutput ( content = chunk .message .content )
371412
372413
373414def from_ollama (
0 commit comments