Skip to content

Commit 34e98bd

Browse files
authored
types: relax type for tools (#550)
1 parent dad9e1c commit 34e98bd

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

ollama/_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def __contains__(self, key: str) -> bool:
7979
if key in self.model_fields_set:
8080
return True
8181

82-
if value := self.model_fields.get(key):
82+
if value := self.__class__.model_fields.get(key):
8383
return value.default is not None
8484

8585
return False
@@ -313,7 +313,7 @@ class Function(SubscriptableBaseModel):
313313

314314

315315
class Tool(SubscriptableBaseModel):
316-
type: Optional[Literal['function']] = 'function'
316+
type: Optional[str] = 'function'
317317

318318
class Function(SubscriptableBaseModel):
319319
name: Optional[str] = None

ollama/_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,12 @@ def convert_function_to_tool(func: Callable) -> Tool:
7979
}
8080

8181
tool = Tool(
82+
type='function',
8283
function=Tool.Function(
8384
name=func.__name__,
8485
description=schema.get('description', ''),
8586
parameters=Tool.Function.Parameters(**schema),
86-
)
87+
),
8788
)
8889

8990
return Tool.model_validate(tool)

tests/test_client.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import pytest
1010
from httpx import Response as httpxResponse
11-
from pydantic import BaseModel, ValidationError
11+
from pydantic import BaseModel
1212
from pytest_httpserver import HTTPServer, URIPattern
1313
from werkzeug.wrappers import Request, Response
1414

@@ -1136,10 +1136,11 @@ def func2(y: str) -> int:
11361136

11371137

11381138
def test_tool_validation():
1139-
# Raises ValidationError when used as it is a generator
1140-
with pytest.raises(ValidationError):
1141-
invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}}
1142-
list(_copy_tools([invalid_tool]))
1139+
arbitrary_tool = {'type': 'custom_type', 'function': {'name': 'test'}}
1140+
tools = list(_copy_tools([arbitrary_tool]))
1141+
assert len(tools) == 1
1142+
assert tools[0].type == 'custom_type'
1143+
assert tools[0].function.name == 'test'
11431144

11441145

11451146
def test_client_connection_error():

0 commit comments

Comments
 (0)