Commit 406f5a3

Implement tools and outputs for the Ollama model

1 parent 7f748f1

File tree: 3 files changed (+108, -53 lines)

outlines/models/ollama.py

Lines changed: 64 additions & 23 deletions
@@ -2,12 +2,22 @@
 
 import json
 from functools import singledispatchmethod
-from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, Optional, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Iterator,
+    List,
+    Optional,
+    Union,
+)
 
 from pydantic import TypeAdapter
 
 from outlines.inputs import Chat, Image
 from outlines.models.base import AsyncModel, Model, ModelTypeAdapter
+from outlines.outputs import Output, StreamingOutput
+from outlines.tools import ToolDef
 from outlines.types import CFG, JsonSchema, Regex
 from outlines.types.utils import (
     is_dataclass,
@@ -74,7 +84,7 @@ def format_chat_model_input(self, model_input: Chat) -> list:
 
         """
         return [
-            self._create_message(message["role"], message["content"])
+            self._create_message(message["role"], message["content"])  # type: ignore
             for message in model_input.messages
         ]
 
@@ -107,9 +117,7 @@ def _create_message(self, role: str, content: str | list) -> dict:
                 "and a list of images."
             )
 
-    def format_output_type(
-        self, output_type: Optional[Any] = None
-    ) -> Optional[str]:
+    def format_output_type(self, output_type: Optional[Any]) -> Optional[str]:
         """Format the output type to pass to the client.
 
         TODO: `int`, `float` and other Python types could be supported via
@@ -159,6 +167,13 @@ def format_output_type(
                 "Consider using a local model instead."
             )
 
+    def format_tools(self, tools):
+        """Not available for Ollama."""
+        if tools:
+            raise NotImplementedError(
+                "Tools are not available for Ollama."
+            )
+
 
 class Ollama(Model):
     """Thin wrapper around the `ollama.Client` client.
@@ -184,9 +199,10 @@ def __init__(self, client: "Client", model_name: Optional[str] = None):
 
     def generate(self,
         model_input: Chat | str | list,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **kwargs: Any,
-    ) -> str:
+    ) -> Output:
         """Generate text using Ollama.
 
         Parameters
@@ -197,15 +213,19 @@ def generate(self,
             The desired format of the response generated by the model. The
             output type must be of a type that can be converted to a JSON
             schema.
+        tools
+            The tools to use for the generation.
         **kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        str
+        Output
             The text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
+
         if "model" not in kwargs and self.model_name is not None:
             kwargs["model"] = self.model_name
 
@@ -214,12 +234,14 @@ def generate(self,
             format=self.type_adapter.format_output_type(output_type),
             **kwargs,
         )
-        return response.message.content
+
+        return Output(content=response.message.content)
 
     def generate_batch(
         self,
         model_input,
-        output_type = None,
+        output_type,
+        tools,
         **kwargs,
     ):
         raise NotImplementedError(
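
Callers now receive a structured `Output` instead of a raw string, and `tools` must be passed explicitly since the parameter no longer has a default. A minimal usage sketch calling the method shown above directly, assuming a running Ollama server; the model name `tinyllama` is a placeholder:

```python
import ollama
from outlines import from_ollama

model = from_ollama(ollama.Client(), "tinyllama")

# `tools` must be None (or empty) for Ollama; anything else raises
# NotImplementedError via the type adapter guard above.
result = model.generate("Name one planet.", output_type=None, tools=None)

# The generated text now lives on the Output object.
print(result.content)
```
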
@@ -229,9 +251,10 @@ def generate_batch(
     def generate_stream(
         self,
         model_input: Chat | str | list,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **kwargs: Any,
-    ) -> Iterator[str]:
+    ) -> Iterator[StreamingOutput]:
         """Stream text using Ollama.
 
         Parameters
@@ -242,15 +265,19 @@ def generate_stream(
             The desired format of the response generated by the model. The
             output type must be of a type that can be converted to a JSON
             schema.
+        tools
+            The tools to use for the generation.
         **kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        Iterator[str]
+        Iterator[StreamingOutput]
             An iterator that yields the text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
+
         if "model" not in kwargs and self.model_name is not None:
             kwargs["model"] = self.model_name
 
@@ -260,8 +287,9 @@ def generate_stream(
             stream=True,
             **kwargs,
         )
+
         for chunk in response:
-            yield chunk.message.content
+            yield StreamingOutput(content=chunk.message.content)
 
 
 class AsyncOllama(AsyncModel):
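
The streaming path gets the same treatment: each chunk is wrapped in a `StreamingOutput` whose `content` field carries the text. A minimal sketch, reusing the `model` from the previous example:

```python
for chunk in model.generate_stream("Count to three.", output_type=None, tools=None):
    # Each chunk is a StreamingOutput, not a bare string.
    print(chunk.content, end="", flush=True)
print()
```
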
@@ -290,9 +318,10 @@ def __init__(
 
     async def generate(self,
         model_input: Chat | str | list,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **kwargs: Any,
-    ) -> str:
+    ) -> Output:
         """Generate text using Ollama.
 
         Parameters
@@ -303,15 +332,19 @@ async def generate(self,
             The desired format of the response generated by the model. The
             output type must be of a type that can be converted to a JSON
             schema.
+        tools
+            The tools to use for the generation.
         **kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        str
+        Output
             The text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
+
         if "model" not in kwargs and self.model_name is not None:
             kwargs["model"] = self.model_name
 
@@ -320,12 +353,14 @@ async def generate(self,
             format=self.type_adapter.format_output_type(output_type),
             **kwargs,
         )
-        return response.message.content
+
+        return Output(content=response.message.content)
 
     async def generate_batch(
         self,
         model_input,
-        output_type = None,
+        output_type,
+        tools,
         **kwargs,
     ):
         raise NotImplementedError(
@@ -335,9 +370,10 @@ async def generate_batch(
     async def generate_stream(  # type: ignore
         self,
         model_input: Chat | str | list,
-        output_type: Optional[Any] = None,
+        output_type: Optional[Any],
+        tools: Optional[List[ToolDef]],
         **kwargs: Any,
-    ) -> AsyncIterator[str]:
+    ) -> AsyncIterator[StreamingOutput]:
         """Stream text using Ollama.
 
         Parameters
@@ -348,15 +384,19 @@ async def generate_stream(  # type: ignore
             The desired format of the response generated by the model. The
             output type must be of a type that can be converted to a JSON
             schema.
+        tools
+            The tools to use for the generation.
         **kwargs
             Additional keyword arguments to pass to the client.
 
         Returns
         -------
-        Iterator[str]
+        Iterator[StreamingOutput]
             An iterator that yields the text generated by the model.
 
         """
+        self.type_adapter.format_tools(tools)
+
         if "model" not in kwargs and self.model_name is not None:
             kwargs["model"] = self.model_name
 
@@ -366,8 +406,9 @@ async def generate_stream(  # type: ignore
             stream=True,
             **kwargs,
         )
+
         async for chunk in stream:
-            yield chunk.message.content
+            yield StreamingOutput(content=chunk.message.content)
 
 
 def from_ollama(
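
The async wrapper mirrors the sync API. A minimal sketch, assuming `from_ollama` dispatches an `ollama.AsyncClient` to `AsyncOllama` (the diff truncates at the `from_ollama` definition, so the dispatch behavior is an assumption) and the same placeholder model name:

```python
import asyncio

import ollama
from outlines import from_ollama


async def main():
    # Assumed: passing an AsyncClient yields the AsyncOllama wrapper.
    model = from_ollama(ollama.AsyncClient(), "tinyllama")

    # Awaited generate returns an Output, as in the sync case.
    result = await model.generate("Name one planet.", output_type=None, tools=None)
    print(result.content)

    # The async stream yields StreamingOutput chunks.
    async for chunk in model.generate_stream("Count to three.", output_type=None, tools=None):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```
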
