Skip to content

Commit 359c63d

Browse files
authored
integrate tool calls (#213)
1 parent 1a15742 commit 359c63d

File tree

5 files changed

+155
-14
lines changed

5 files changed

+155
-14
lines changed

examples/tools/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# tools
2+
3+
This example demonstrates how to utilize tool calls with an asynchronous Ollama client and the chat endpoint.

examples/tools/main.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import json
2+
import ollama
3+
import asyncio
4+
5+
6+
def get_flight_times(departure: str, arrival: str) -> str:
    """Simulate an API call that looks up flight times between two cities.

    In a real application this would query a live database or service.

    Args:
        departure: Departure city (airport code), case-insensitive.
        arrival: Arrival city (airport code), case-insensitive.

    Returns:
        A JSON string with 'departure', 'arrival', and 'duration' fields,
        or a JSON object with an 'error' field when the route is unknown.
    """
    schedule = {
        'NYC-LAX': {'departure': '08:00 AM', 'arrival': '11:30 AM', 'duration': '5h 30m'},
        'LAX-NYC': {'departure': '02:00 PM', 'arrival': '10:30 PM', 'duration': '5h 30m'},
        'LHR-JFK': {'departure': '10:00 AM', 'arrival': '01:00 PM', 'duration': '8h 00m'},
        'JFK-LHR': {'departure': '09:00 PM', 'arrival': '09:00 AM', 'duration': '7h 00m'},
        'CDG-DXB': {'departure': '11:00 AM', 'arrival': '08:00 PM', 'duration': '6h 00m'},
        'DXB-CDG': {'departure': '03:00 AM', 'arrival': '07:30 AM', 'duration': '7h 30m'},
    }

    # Routes are keyed as 'FROM-TO' in upper case, so lookups are case-insensitive.
    route = f'{departure}-{arrival}'.upper()
    info = schedule.get(route)
    if info is None:
        info = {'error': 'Flight not found'}
    return json.dumps(info)
20+
21+
22+
async def run(model: str):
    """Run a two-turn tool-calling conversation against an Ollama model.

    Sends a flight-time question plus the `get_flight_times` tool schema,
    executes any tool calls the model requests, feeds the results back, and
    prints the model's final answer.

    Args:
        model: Name of the Ollama model to chat with (e.g. 'mistral').
    """
    client = ollama.AsyncClient()
    # Initialize conversation with a user query
    messages = [{'role': 'user', 'content': 'What is the flight time from New York (NYC) to Los Angeles (LAX)?'}]

    # First API call: Send the query and function description to the model
    response = await client.chat(
        model=model,
        messages=messages,
        tools=[
            {
                'type': 'function',
                'function': {
                    'name': 'get_flight_times',
                    'description': 'Get the flight times between two cities',
                    'parameters': {
                        'type': 'object',
                        'properties': {
                            'departure': {
                                'type': 'string',
                                'description': 'The departure city (airport code)',
                            },
                            'arrival': {
                                'type': 'string',
                                'description': 'The arrival city (airport code)',
                            },
                        },
                        'required': ['departure', 'arrival'],
                    },
                },
            },
        ],
    )

    # Add the model's response to the conversation history
    messages.append(response['message'])

    # Check if the model decided to use the provided function
    if not response['message'].get('tool_calls'):
        print("The model didn't use the function. Its response was:")
        print(response['message']['content'])
        return

    # Process function calls made by the model.  (The early return above
    # guarantees tool_calls is present here, so no second check is needed.)
    available_functions = {
        'get_flight_times': get_flight_times,
    }
    for tool in response['message']['tool_calls']:
        name = tool['function']['name']
        function_to_call = available_functions.get(name)
        if function_to_call is None:
            # The model asked for a tool we never offered; tell it so instead
            # of crashing with a KeyError.
            messages.append(
                {
                    'role': 'tool',
                    'content': json.dumps({'error': f'Unknown function: {name}'}),
                }
            )
            continue
        # Unpack the arguments by keyword so any registered tool signature
        # works, not just one with exactly (departure, arrival).
        function_response = function_to_call(**tool['function']['arguments'])
        # Add function response to the conversation
        messages.append(
            {
                'role': 'tool',
                'content': function_response,
            }
        )

    # Second API call: Get final response from the model
    final_response = await client.chat(model=model, messages=messages)
    print(final_response['message']['content'])
84+
85+
86+
# Run the example only when executed as a script, so importing this module
# (e.g. from docs tooling or tests) does not trigger a network call.
if __name__ == '__main__':
    asyncio.run(run('mistral'))

ollama/_client.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
except metadata.PackageNotFoundError:
2828
__version__ = '0.0.0'
2929

30-
from ollama._types import Message, Options, RequestError, ResponseError
30+
from ollama._types import Message, Options, RequestError, ResponseError, Tool
3131

3232

3333
class BaseClient:
@@ -180,6 +180,7 @@ def chat(
180180
self,
181181
model: str = '',
182182
messages: Optional[Sequence[Message]] = None,
183+
tools: Optional[Sequence[Tool]] = None,
183184
stream: Literal[False] = False,
184185
format: Literal['', 'json'] = '',
185186
options: Optional[Options] = None,
@@ -191,6 +192,7 @@ def chat(
191192
self,
192193
model: str = '',
193194
messages: Optional[Sequence[Message]] = None,
195+
tools: Optional[Sequence[Tool]] = None,
194196
stream: Literal[True] = True,
195197
format: Literal['', 'json'] = '',
196198
options: Optional[Options] = None,
@@ -201,6 +203,7 @@ def chat(
201203
self,
202204
model: str = '',
203205
messages: Optional[Sequence[Message]] = None,
206+
tools: Optional[Sequence[Tool]] = None,
204207
stream: bool = False,
205208
format: Literal['', 'json'] = '',
206209
options: Optional[Options] = None,
@@ -222,12 +225,6 @@ def chat(
222225
messages = deepcopy(messages)
223226

224227
for message in messages or []:
225-
if not isinstance(message, dict):
226-
raise TypeError('messages must be a list of Message or dict-like objects')
227-
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
228-
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
229-
if 'content' not in message:
230-
raise RequestError('messages must contain content')
231228
if images := message.get('images'):
232229
message['images'] = [_encode_image(image) for image in images]
233230

@@ -237,6 +234,7 @@ def chat(
237234
json={
238235
'model': model,
239236
'messages': messages,
237+
'tools': tools or [],
240238
'stream': stream,
241239
'format': format,
242240
'options': options or {},
@@ -574,6 +572,7 @@ async def chat(
574572
self,
575573
model: str = '',
576574
messages: Optional[Sequence[Message]] = None,
575+
tools: Optional[Sequence[Tool]] = None,
577576
stream: Literal[False] = False,
578577
format: Literal['', 'json'] = '',
579578
options: Optional[Options] = None,
@@ -585,6 +584,7 @@ async def chat(
585584
self,
586585
model: str = '',
587586
messages: Optional[Sequence[Message]] = None,
587+
tools: Optional[Sequence[Tool]] = None,
588588
stream: Literal[True] = True,
589589
format: Literal['', 'json'] = '',
590590
options: Optional[Options] = None,
@@ -595,6 +595,7 @@ async def chat(
595595
self,
596596
model: str = '',
597597
messages: Optional[Sequence[Message]] = None,
598+
tools: Optional[Sequence[Tool]] = None,
598599
stream: bool = False,
599600
format: Literal['', 'json'] = '',
600601
options: Optional[Options] = None,
@@ -615,12 +616,6 @@ async def chat(
615616
messages = deepcopy(messages)
616617

617618
for message in messages or []:
618-
if not isinstance(message, dict):
619-
raise TypeError('messages must be a list of strings')
620-
if not (role := message.get('role')) or role not in ['system', 'user', 'assistant']:
621-
raise RequestError('messages must contain a role and it must be one of "system", "user", or "assistant"')
622-
if 'content' not in message:
623-
raise RequestError('messages must contain content')
624619
if images := message.get('images'):
625620
message['images'] = [_encode_image(image) for image in images]
626621

@@ -630,6 +625,7 @@ async def chat(
630625
json={
631626
'model': model,
632627
'messages': messages,
628+
'tools': tools or [],
633629
'stream': stream,
634630
'format': format,
635631
'options': options or {},

ollama/_types.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, TypedDict, Sequence, Literal
2+
from typing import Any, TypedDict, Sequence, Literal, Mapping
33

44
import sys
55

@@ -53,6 +53,27 @@ class GenerateResponse(BaseGenerateResponse):
5353
'Tokenized history up to the point of the response.'
5454

5555

56+
class ToolCallFunction(TypedDict):
57+
"""
58+
Tool call function.
59+
"""
60+
61+
name: str
62+
'Name of the function.'
63+
64+
args: NotRequired[Mapping[str, Any]]
65+
'Arguments of the function.'
66+
67+
68+
class ToolCall(TypedDict):
  """
  Model tool calls: a single tool invocation requested by the model
  in a chat response (carried on `Message.tool_calls`).
  """

  function: ToolCallFunction
  'Function to be called.'
75+
76+
5677
class Message(TypedDict):
5778
"""
5879
Chat message.
@@ -76,6 +97,34 @@ class Message(TypedDict):
7697
Valid image formats depend on the model. See the model card for more information.
7798
"""
7899

100+
tool_calls: NotRequired[Sequence[ToolCall]]
101+
"""
102+
Tools calls to be made by the model.
103+
"""
104+
105+
106+
class Property(TypedDict):
  """
  JSON-Schema-style description of a single tool-function parameter.
  """

  type: str  # JSON schema type, e.g. 'string'
  description: str  # human-readable description shown to the model
  enum: NotRequired[Sequence[str]]  # `enum` is optional and can be a list of strings
110+
111+
112+
class Parameters(TypedDict):
  """
  JSON-Schema-style parameter specification for a tool function.
  """

  type: str  # e.g. 'object', as in the tools example
  required: Sequence[str]  # names of parameters that must be supplied
  properties: Mapping[str, Property]  # parameter name -> schema
116+
117+
118+
class ToolFunction(TypedDict):
  """
  Description of a callable function offered to the model as a tool.
  """

  name: str  # function name the model will reference in tool calls
  description: str  # what the function does, shown to the model
  parameters: Parameters  # schema of the accepted arguments
122+
123+
124+
class Tool(TypedDict):
  """
  A tool entry passed via the `tools` parameter of `chat`.
  """

  type: str  # e.g. 'function', as in the tools example
  function: ToolFunction  # the function definition itself
127+
79128

80129
class ChatResponse(BaseGenerateResponse):
81130
"""

tests/test_client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def test_client_chat(httpserver: HTTPServer):
2626
json={
2727
'model': 'dummy',
2828
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
29+
'tools': [],
2930
'stream': False,
3031
'format': '',
3132
'options': {},
@@ -73,6 +74,7 @@ def generate():
7374
json={
7475
'model': 'dummy',
7576
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
77+
'tools': [],
7678
'stream': True,
7779
'format': '',
7880
'options': {},
@@ -102,6 +104,7 @@ def test_client_chat_images(httpserver: HTTPServer):
102104
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
103105
},
104106
],
107+
'tools': [],
105108
'stream': False,
106109
'format': '',
107110
'options': {},
@@ -522,6 +525,7 @@ async def test_async_client_chat(httpserver: HTTPServer):
522525
json={
523526
'model': 'dummy',
524527
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
528+
'tools': [],
525529
'stream': False,
526530
'format': '',
527531
'options': {},
@@ -560,6 +564,7 @@ def generate():
560564
json={
561565
'model': 'dummy',
562566
'messages': [{'role': 'user', 'content': 'Why is the sky blue?'}],
567+
'tools': [],
563568
'stream': True,
564569
'format': '',
565570
'options': {},
@@ -590,6 +595,7 @@ async def test_async_client_chat_images(httpserver: HTTPServer):
590595
'images': ['iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAIAAACQd1PeAAAADElEQVR4nGNgYGAAAAAEAAH2FzhVAAAAAElFTkSuQmCC'],
591596
},
592597
],
598+
'tools': [],
593599
'stream': False,
594600
'format': '',
595601
'options': {},

0 commit comments

Comments
 (0)