Skip to content

Commit 139c89e

Browse files
authored
Passing Functions as Tools (#321)
* Functions can now be passed as tools
1 parent da2893b commit 139c89e

File tree

6 files changed

+545
-39
lines changed

6 files changed

+545
-39
lines changed

ollama/_client.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from typing import (
1212
Any,
13+
Callable,
1314
Literal,
1415
Mapping,
1516
Optional,
@@ -22,6 +23,9 @@
2223

2324
import sys
2425

26+
27+
from ollama._utils import convert_function_to_tool
28+
2529
if sys.version_info < (3, 9):
2630
from typing import Iterator, AsyncIterator
2731
else:
@@ -284,7 +288,7 @@ def chat(
284288
model: str = '',
285289
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
286290
*,
287-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
291+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
288292
stream: bool = False,
289293
format: Optional[Literal['', 'json']] = None,
290294
options: Optional[Union[Mapping[str, Any], Options]] = None,
@@ -293,6 +297,30 @@ def chat(
293297
"""
294298
Create a chat response using the requested model.
295299
300+
Args:
301+
tools:
302+
A JSON schema as a dict, an Ollama Tool or a Python Function.
303+
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
304+
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
305+
stream: Whether to stream the response.
306+
format: The format of the response.
307+
308+
Example:
309+
def add_two_numbers(a: int, b: int) -> int:
310+
'''
311+
Add two numbers together.
312+
313+
Args:
314+
a: First number to add
315+
b: Second number to add
316+
317+
Returns:
318+
int: The sum of a and b
319+
'''
320+
return a + b
321+
322+
client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
323+
296324
Raises `RequestError` if a model is not provided.
297325
298326
Raises `ResponseError` if the request could not be fulfilled.
@@ -750,7 +778,7 @@ async def chat(
750778
model: str = '',
751779
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
752780
*,
753-
tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]] = None,
781+
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
754782
stream: Literal[True] = True,
755783
format: Optional[Literal['', 'json']] = None,
756784
options: Optional[Union[Mapping[str, Any], Options]] = None,
@@ -771,6 +799,30 @@ async def chat(
771799
"""
772800
Create a chat response using the requested model.
773801
802+
Args:
803+
tools:
804+
A JSON schema as a dict, an Ollama Tool or a Python Function.
805+
Python functions need to follow Google style docstrings to be converted to an Ollama Tool.
806+
For more information, see: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
807+
stream: Whether to stream the response.
808+
format: The format of the response.
809+
810+
Example:
811+
def add_two_numbers(a: int, b: int) -> int:
812+
'''
813+
Add two numbers together.
814+
815+
Args:
816+
a: First number to add
817+
b: Second number to add
818+
819+
Returns:
820+
int: The sum of a and b
821+
'''
822+
return a + b
823+
824+
await client.chat(model='llama3.1:8b', tools=[add_two_numbers], messages=[...])
825+
774826
Raises `RequestError` if a model is not provided.
775827
776828
Raises `ResponseError` if the request could not be fulfilled.
@@ -1075,9 +1127,9 @@ def _copy_messages(messages: Optional[Sequence[Union[Mapping[str, Any], Message]
10751127
)
10761128

10771129

1078-
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool]]]) -> Iterator[Tool]:
1079-
for tool in tools or []:
1080-
yield Tool.model_validate(tool)
1130+
def _copy_tools(tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None) -> Iterator[Tool]:
1131+
for unprocessed_tool in tools or []:
1132+
yield convert_function_to_tool(unprocessed_tool) if callable(unprocessed_tool) else Tool.model_validate(unprocessed_tool)
10811133

10821134

10831135
def _as_path(s: Optional[Union[str, PathLike]]) -> Union[Path, None]:

ollama/_types.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,18 @@
11
import json
2-
from base64 import b64encode
2+
from base64 import b64decode, b64encode
33
from pathlib import Path
44
from datetime import datetime
5-
from typing import (
6-
Any,
7-
Literal,
8-
Mapping,
9-
Optional,
10-
Sequence,
11-
Union,
12-
)
13-
from typing_extensions import Annotated
5+
from typing import Any, Mapping, Optional, Union, Sequence
6+
7+
from typing_extensions import Annotated, Literal
148

159
from pydantic import (
1610
BaseModel,
1711
ByteSize,
12+
ConfigDict,
1813
Field,
19-
FilePath,
20-
Base64Str,
2114
model_serializer,
2215
)
23-
from pydantic.json_schema import JsonSchemaValue
2416

2517

2618
class SubscriptableBaseModel(BaseModel):
@@ -95,16 +87,26 @@ class BaseGenerateRequest(BaseStreamableRequest):
9587

9688

9789
class Image(BaseModel):
98-
value: Union[FilePath, Base64Str, bytes]
90+
value: Union[str, bytes, Path]
9991

100-
# This overloads the `model_dump` method and returns values depending on the type of the `value` field
10192
@model_serializer
10293
def serialize_model(self):
103-
if isinstance(self.value, Path):
104-
return b64encode(self.value.read_bytes()).decode()
105-
elif isinstance(self.value, bytes):
106-
return b64encode(self.value).decode()
107-
return self.value
94+
if isinstance(self.value, (Path, bytes)):
95+
return b64encode(self.value.read_bytes() if isinstance(self.value, Path) else self.value).decode()
96+
97+
if isinstance(self.value, str):
98+
if Path(self.value).exists():
99+
return b64encode(Path(self.value).read_bytes()).decode()
100+
101+
if self.value.split('.')[-1] in ('png', 'jpg', 'jpeg', 'webp'):
102+
raise ValueError(f'File {self.value} does not exist')
103+
104+
try:
105+
# Try to decode to check if it's already base64
106+
b64decode(self.value)
107+
return self.value
108+
except Exception:
109+
raise ValueError('Invalid image data, expected base64 string or path to image file') from Exception
108110

109111

110112
class GenerateRequest(BaseGenerateRequest):
@@ -222,20 +224,27 @@ class Function(SubscriptableBaseModel):
222224

223225

224226
class Tool(SubscriptableBaseModel):
225-
type: Literal['function'] = 'function'
227+
type: Optional[Literal['function']] = 'function'
226228

227229
class Function(SubscriptableBaseModel):
228-
name: str
229-
description: str
230+
name: Optional[str] = None
231+
description: Optional[str] = None
230232

231233
class Parameters(SubscriptableBaseModel):
232-
type: str
234+
type: Optional[Literal['object']] = 'object'
233235
required: Optional[Sequence[str]] = None
234-
properties: Optional[JsonSchemaValue] = None
235236

236-
parameters: Parameters
237+
class Property(SubscriptableBaseModel):
238+
model_config = ConfigDict(arbitrary_types_allowed=True)
239+
240+
type: Optional[str] = None
241+
description: Optional[str] = None
242+
243+
properties: Optional[Mapping[str, Property]] = None
237244

238-
function: Function
245+
parameters: Optional[Parameters] = None
246+
247+
function: Optional[Function] = None
239248

240249

241250
class ChatRequest(BaseGenerateRequest):
@@ -335,6 +344,7 @@ class ModelDetails(SubscriptableBaseModel):
335344

336345
class ListResponse(SubscriptableBaseModel):
337346
class Model(SubscriptableBaseModel):
347+
model: Optional[str] = None
338348
modified_at: Optional[datetime] = None
339349
digest: Optional[str] = None
340350
size: Optional[ByteSize] = None

ollama/_utils.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from __future__ import annotations
2+
from collections import defaultdict
3+
import inspect
4+
from typing import Callable, Union
5+
import re
6+
7+
import pydantic
8+
from ollama._types import Tool
9+
10+
11+
def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]:
12+
parsed_docstring = defaultdict(str)
13+
if not doc_string:
14+
return parsed_docstring
15+
16+
key = hash(doc_string)
17+
for line in doc_string.splitlines():
18+
lowered_line = line.lower().strip()
19+
if lowered_line.startswith('args:'):
20+
key = 'args'
21+
elif lowered_line.startswith('returns:') or lowered_line.startswith('yields:') or lowered_line.startswith('raises:'):
22+
key = '_'
23+
24+
else:
25+
# maybe change to a list and join later
26+
parsed_docstring[key] += f'{line.strip()}\n'
27+
28+
last_key = None
29+
for line in parsed_docstring['args'].splitlines():
30+
line = line.strip()
31+
if ':' in line:
32+
# Split the line on either:
33+
# 1. A parenthetical expression like (integer) - captured in group 1
34+
# 2. A colon :
35+
# Followed by optional whitespace. Only split on first occurrence.
36+
parts = re.split(r'(?:\(([^)]*)\)|:)\s*', line, maxsplit=1)
37+
38+
arg_name = parts[0].strip()
39+
last_key = arg_name
40+
41+
# Get the description - will be in parts[1] if parenthetical or parts[-1] if after colon
42+
arg_description = parts[-1].strip()
43+
if len(parts) > 2 and parts[1]: # Has parenthetical content
44+
arg_description = parts[-1].split(':', 1)[-1].strip()
45+
46+
parsed_docstring[last_key] = arg_description
47+
48+
elif last_key and line:
49+
parsed_docstring[last_key] += ' ' + line
50+
51+
return parsed_docstring
52+
53+
54+
def convert_function_to_tool(func: Callable) -> Tool:
55+
doc_string_hash = hash(inspect.getdoc(func))
56+
parsed_docstring = _parse_docstring(inspect.getdoc(func))
57+
schema = type(
58+
func.__name__,
59+
(pydantic.BaseModel,),
60+
{
61+
'__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()},
62+
'__signature__': inspect.signature(func),
63+
'__doc__': parsed_docstring[doc_string_hash],
64+
},
65+
).model_json_schema()
66+
67+
for k, v in schema.get('properties', {}).items():
68+
# If type is missing, the default is string
69+
types = {t.get('type', 'string') for t in v.get('anyOf')} if 'anyOf' in v else {v.get('type', 'string')}
70+
if 'null' in types:
71+
schema['required'].remove(k)
72+
types.discard('null')
73+
74+
schema['properties'][k] = {
75+
'description': parsed_docstring[k],
76+
'type': ', '.join(types),
77+
}
78+
79+
tool = Tool(
80+
function=Tool.Function(
81+
name=func.__name__,
82+
description=schema.get('description', ''),
83+
parameters=Tool.Function.Parameters(**schema),
84+
)
85+
)
86+
87+
return Tool.model_validate(tool)

tests/test_client.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import os
22
import io
33
import json
4+
from pydantic import ValidationError
45
import pytest
56
import tempfile
67
from pathlib import Path
78
from pytest_httpserver import HTTPServer, URIPattern
89
from werkzeug.wrappers import Request, Response
910
from PIL import Image
1011

11-
from ollama._client import Client, AsyncClient
12+
from ollama._client import Client, AsyncClient, _copy_tools
1213

1314

1415
class PrefixPattern(URIPattern):
@@ -982,3 +983,56 @@ def test_headers():
982983
)
983984
assert client._client.headers['x-custom'] == 'value'
984985
assert client._client.headers['content-type'] == 'application/json'
986+
987+
988+
def test_copy_tools():
989+
def func1(x: int) -> str:
990+
"""Simple function 1.
991+
Args:
992+
x (integer): A number
993+
"""
994+
pass
995+
996+
def func2(y: str) -> int:
997+
"""Simple function 2.
998+
Args:
999+
y (string): A string
1000+
"""
1001+
pass
1002+
1003+
# Test with list of functions
1004+
tools = list(_copy_tools([func1, func2]))
1005+
assert len(tools) == 2
1006+
assert tools[0].function.name == 'func1'
1007+
assert tools[1].function.name == 'func2'
1008+
1009+
# Test with empty input
1010+
assert list(_copy_tools()) == []
1011+
assert list(_copy_tools(None)) == []
1012+
assert list(_copy_tools([])) == []
1013+
1014+
# Test with mix of functions and tool dicts
1015+
tool_dict = {
1016+
'type': 'function',
1017+
'function': {
1018+
'name': 'test',
1019+
'description': 'Test function',
1020+
'parameters': {
1021+
'type': 'object',
1022+
'properties': {'x': {'type': 'string', 'description': 'A string'}},
1023+
'required': ['x'],
1024+
},
1025+
},
1026+
}
1027+
1028+
tools = list(_copy_tools([func1, tool_dict]))
1029+
assert len(tools) == 2
1030+
assert tools[0].function.name == 'func1'
1031+
assert tools[1].function.name == 'test'
1032+
1033+
1034+
def test_tool_validation():
1035+
# Raises ValidationError when used as it is a generator
1036+
with pytest.raises(ValidationError):
1037+
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
1038+
list(_copy_tools([invalid_tool]))

0 commit comments

Comments
 (0)