-
Couldn't load subscription status.
- Fork 845
Passing Functions as Tools #321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 29 commits
4383603
afe7db6
8fee892
0e5a940
1ef75a7
93c7a63
e5dc2b8
aa20015
d79538e
97aa167
8ec5123
efb775b
2efa54a
1f089f7
fe8d143
67321a8
2cc0b40
e68700c
f452fab
ca16670
7dcb598
7c5c294
16c868a
718412a
e7bb55f
7396ab6
0d9eec0
ed3ba8a
a4ec34a
6d9c156
c5c61a3
b0e0409
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,81 @@ | ||
| from __future__ import annotations | ||
| from collections import defaultdict | ||
| import inspect | ||
| from typing import Callable, Union | ||
|
|
||
| import pydantic | ||
| from ollama._types import Tool | ||
|
|
||
|
|
||
| def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]: | ||
| parsed_docstring = defaultdict(str) | ||
| if not doc_string: | ||
| return parsed_docstring | ||
|
|
||
| lowered_doc_string = doc_string.lower() | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| key = hash(doc_string) | ||
| parsed_docstring[key] = '' | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| for line in lowered_doc_string.splitlines(): | ||
| if line.startswith('args:'): | ||
| key = 'args' | ||
| elif line.startswith('returns:') or line.startswith('yields:') or line.startswith('raises:'): | ||
| key = '_' | ||
|
|
||
| else: | ||
| # maybe change to a list and join later | ||
| parsed_docstring[key] += f'{line.strip()}\n' | ||
|
|
||
| last_key = None | ||
| for line in parsed_docstring['args'].splitlines(): | ||
| line = line.strip() | ||
| if ':' in line and not line.startswith('args'): | ||
| # Split on first occurrence of '(' or ':' to separate arg name from description | ||
| split_char = '(' if '(' in line else ':' | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| arg_name, rest = line.split(split_char, 1) | ||
|
|
||
| last_key = arg_name.strip() | ||
| # Get description after the colon | ||
| arg_description = rest.split(':', 1)[1].strip() if split_char == '(' else rest.strip() | ||
| parsed_docstring[last_key] = arg_description | ||
|
|
||
| elif last_key and line: | ||
| parsed_docstring[last_key] += ' ' + line | ||
|
|
||
| return parsed_docstring | ||
|
|
||
|
|
||
| def convert_function_to_tool(func: Callable) -> Tool: | ||
| doc_string_hash = hash(inspect.getdoc(func)) | ||
| parsed_docstring = _parse_docstring(inspect.getdoc(func)) | ||
| schema = type( | ||
| func.__name__, | ||
| (pydantic.BaseModel,), | ||
| { | ||
| '__annotations__': {k: v.annotation if v.annotation != inspect._empty else str for k, v in inspect.signature(func).parameters.items()}, | ||
| '__signature__': inspect.signature(func), | ||
| '__doc__': parsed_docstring[doc_string_hash], | ||
| }, | ||
| ).model_json_schema() | ||
|
|
||
| for k, v in schema.get('properties', {}).items(): | ||
ParthSareen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # If type is missing, the default is string | ||
| types = {t.get('type', 'string') for t in v.get('anyOf')} if 'anyOf' in v else {v.get('type', 'string')} | ||
| if 'null' in types: | ||
| schema['required'].remove(k) | ||
| types.discard('null') | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is okay, IMO something like: def (a:None, b:type(None)):
...is extremely unlikely |
||
|
|
||
| schema['properties'][k] = { | ||
| 'description': parsed_docstring[k], | ||
| 'type': ', '.join(types), | ||
| } | ||
|
|
||
| tool = Tool( | ||
| function=Tool.Function( | ||
| name=func.__name__, | ||
| description=schema.get('description', ''), | ||
| parameters=Tool.Function.Parameters(**schema), | ||
| ) | ||
| ) | ||
|
|
||
| return Tool.model_validate(tool) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,15 @@ | ||
| import os | ||
| import io | ||
| import json | ||
| from pydantic import ValidationError | ||
| import pytest | ||
| import tempfile | ||
| from pathlib import Path | ||
| from pytest_httpserver import HTTPServer, URIPattern | ||
| from werkzeug.wrappers import Request, Response | ||
| from PIL import Image | ||
|
|
||
| from ollama._client import Client, AsyncClient | ||
| from ollama._client import Client, AsyncClient, _copy_tools | ||
|
|
||
|
|
||
| class PrefixPattern(URIPattern): | ||
|
|
@@ -982,3 +983,67 @@ def test_headers(): | |
| ) | ||
| assert client._client.headers['x-custom'] == 'value' | ||
| assert client._client.headers['content-type'] == 'application/json' | ||
|
|
||
|
|
||
| def test_copy_tools(): | ||
| def func1(x: int) -> str: | ||
| """Simple function 1. | ||
| Args: | ||
| x (integer): A number | ||
| """ | ||
| pass | ||
|
|
||
| def func2(y: str) -> int: | ||
| """Simple function 2. | ||
| Args: | ||
| y (string): A string | ||
| """ | ||
| pass | ||
|
|
||
| # Test with list of functions | ||
| tools = list(_copy_tools([func1, func2])) | ||
| assert len(tools) == 2 | ||
| assert tools[0].function.name == 'func1' | ||
| assert tools[1].function.name == 'func2' | ||
|
|
||
| # Test with empty input | ||
| assert list(_copy_tools()) == [] | ||
| assert list(_copy_tools(None)) == [] | ||
| assert list(_copy_tools([])) == [] | ||
|
|
||
| # Test with mix of functions and tool dicts | ||
| tool_dict = { | ||
| 'type': 'function', | ||
| 'function': { | ||
| 'name': 'test', | ||
| 'description': 'Test function', | ||
| 'parameters': { | ||
| 'type': 'object', | ||
| 'properties': {'x': {'type': 'string', 'description': 'A string'}}, | ||
| 'required': ['x'], | ||
| }, | ||
| }, | ||
| } | ||
|
|
||
| tool_json = json.loads(json.dumps(tool_dict)) | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| tools = list(_copy_tools([func1, tool_dict, tool_json])) | ||
| assert len(tools) == 3 | ||
| assert tools[0].function.name == 'func1' | ||
| assert tools[1].function.name == 'test' | ||
| assert tools[2].function.name == 'test' | ||
|
|
||
|
|
||
| def test_tool_validation(): | ||
| # Test that malformed tool dictionaries are rejected | ||
| # Raises ValidationError when used as it is a generator | ||
| with pytest.raises(ValidationError): | ||
| invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}} | ||
| list(_copy_tools([invalid_tool])) | ||
|
|
||
| # Test missing required fields | ||
| incomplete_tool = { | ||
| 'type': 'function', | ||
| 'function': {'name': 'test'}, # missing description and parameters | ||
|
||
| } | ||
| with pytest.raises(ValidationError): | ||
| list(_copy_tools([incomplete_tool])) | ||
Uh oh!
There was an error while loading. Please reload this page.