-
Couldn't load subscription status.
- Fork 845
Passing Functions as Tools #321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 27 commits
4383603
afe7db6
8fee892
0e5a940
1ef75a7
93c7a63
e5dc2b8
aa20015
d79538e
97aa167
8ec5123
efb775b
2efa54a
1f089f7
fe8d143
67321a8
2cc0b40
e68700c
f452fab
ca16670
7dcb598
7c5c294
16c868a
718412a
e7bb55f
7396ab6
0d9eec0
ed3ba8a
a4ec34a
6d9c156
c5c61a3
b0e0409
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| from __future__ import annotations | ||
| import inspect | ||
| from typing import Callable, Union | ||
|
|
||
| import pydantic | ||
| from ollama._types import Tool | ||
|
|
||
|
|
||
| def _parse_docstring(doc_string: Union[str, None]) -> dict[str, str]: | ||
| parsed_docstring = {'description': ''} | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if not doc_string: | ||
| return parsed_docstring | ||
|
|
||
| lowered_doc_string = doc_string.lower() | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| if 'args:' not in lowered_doc_string: | ||
| parsed_docstring['description'] = lowered_doc_string.strip() | ||
| return parsed_docstring | ||
|
|
||
| else: | ||
| parsed_docstring['description'] = lowered_doc_string.split('args:')[0].strip() | ||
| args_section = lowered_doc_string.split('args:')[1] | ||
|
|
||
| if 'returns:' in lowered_doc_string: | ||
| # Return section can be captured and used | ||
| args_section = args_section.split('returns:')[0] | ||
|
|
||
| if 'yields:' in lowered_doc_string: | ||
| args_section = args_section.split('yields:')[0] | ||
|
|
||
| cur_var = None | ||
| for line in args_section.split('\n'): | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| line = line.strip() | ||
| if not line: | ||
| continue | ||
| if ':' not in line: | ||
| # Continuation of the previous parameter's description | ||
| if cur_var: | ||
| parsed_docstring[cur_var] += f' {line}' | ||
| continue | ||
|
|
||
| # For the case with: `param_name (type)`: ... | ||
| if '(' in line: | ||
| param_name = line.split('(')[0] | ||
| param_desc = line.split('):')[1] | ||
|
|
||
| # For the case with: `param_name: ...` | ||
| else: | ||
| param_name, param_desc = line.split(':', 1) | ||
|
|
||
| parsed_docstring[param_name.strip()] = param_desc.strip() | ||
| cur_var = param_name.strip() | ||
|
|
||
| return parsed_docstring | ||
|
|
||
|
|
||
| def convert_function_to_tool(func: Callable) -> Tool: | ||
| schema = type( | ||
| func.__name__, | ||
| (pydantic.BaseModel,), | ||
| { | ||
| '__annotations__': {k: v.annotation for k, v in inspect.signature(func).parameters.items()}, | ||
| '__signature__': inspect.signature(func), | ||
| '__doc__': inspect.getdoc(func), | ||
| }, | ||
| ).model_json_schema() | ||
|
|
||
| properties = {} | ||
| required = [] | ||
| parsed_docstring = _parse_docstring(schema.get('description')) | ||
| for k, v in schema.get('properties', {}).items(): | ||
ParthSareen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| prop = { | ||
| 'description': parsed_docstring.get(k, ''), | ||
| 'type': v.get('type'), | ||
| } | ||
|
|
||
| if 'anyOf' in v: | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| is_optional = any(t.get('type') == 'null' for t in v['anyOf']) | ||
| types = [t.get('type', 'string') for t in v['anyOf'] if t.get('type') != 'null'] | ||
| prop['type'] = types[0] if len(types) == 1 else str(types) | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if not is_optional: | ||
| required.append(k) | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| else: | ||
| if prop['type'] != 'null': | ||
| required.append(k) | ||
|
|
||
| properties[k] = prop | ||
|
|
||
| schema['properties'] = properties | ||
|
|
||
| tool = Tool( | ||
| function=Tool.Function( | ||
| name=func.__name__, | ||
| description=parsed_docstring.get('description'), | ||
| parameters=Tool.Function.Parameters( | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| type='object', | ||
| properties=schema.get('properties', {}), | ||
| required=required, | ||
| ), | ||
| ) | ||
| ) | ||
|
|
||
| return Tool.model_validate(tool) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,15 @@ | ||
| import os | ||
| import io | ||
| import json | ||
| from pydantic import ValidationError | ||
| import pytest | ||
| import tempfile | ||
| from pathlib import Path | ||
| from pytest_httpserver import HTTPServer, URIPattern | ||
| from werkzeug.wrappers import Request, Response | ||
| from PIL import Image | ||
|
|
||
| from ollama._client import Client, AsyncClient | ||
| from ollama._client import Client, AsyncClient, _copy_tools | ||
|
|
||
|
|
||
| class PrefixPattern(URIPattern): | ||
|
|
@@ -982,3 +983,67 @@ def test_headers(): | |
| ) | ||
| assert client._client.headers['x-custom'] == 'value' | ||
| assert client._client.headers['content-type'] == 'application/json' | ||
|
|
||
|
|
||
| def test_copy_tools(): | ||
| def func1(x: int) -> str: | ||
| """Simple function 1. | ||
| Args: | ||
| x (integer): A number | ||
| """ | ||
| pass | ||
|
|
||
| def func2(y: str) -> int: | ||
| """Simple function 2. | ||
| Args: | ||
| y (string): A string | ||
| """ | ||
| pass | ||
|
|
||
| # Test with list of functions | ||
| tools = list(_copy_tools([func1, func2])) | ||
| assert len(tools) == 2 | ||
| assert tools[0].function.name == 'func1' | ||
| assert tools[1].function.name == 'func2' | ||
|
|
||
| # Test with empty input | ||
| assert list(_copy_tools()) == [] | ||
| assert list(_copy_tools(None)) == [] | ||
| assert list(_copy_tools([])) == [] | ||
|
|
||
| # Test with mix of functions and tool dicts | ||
| tool_dict = { | ||
| 'type': 'function', | ||
| 'function': { | ||
| 'name': 'test', | ||
| 'description': 'Test function', | ||
| 'parameters': { | ||
| 'type': 'object', | ||
| 'properties': {'x': {'type': 'string', 'description': 'A string'}}, | ||
| 'required': ['x'], | ||
| }, | ||
| }, | ||
| } | ||
|
|
||
| tool_json = json.loads(json.dumps(tool_dict)) | ||
ParthSareen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| tools = list(_copy_tools([func1, tool_dict, tool_json])) | ||
| assert len(tools) == 3 | ||
| assert tools[0].function.name == 'func1' | ||
| assert tools[1].function.name == 'test' | ||
| assert tools[2].function.name == 'test' | ||
|
|
||
|
|
||
| def test_tool_validation(): | ||
| # Test that malformed tool dictionaries are rejected | ||
| # Raises ValidationError when used as it is a generator | ||
| with pytest.raises(ValidationError): | ||
| invalid_tool = {'type': 'invalid_type', 'function': {'name': 'test'}} | ||
| list(_copy_tools([invalid_tool])) | ||
|
|
||
| # Test missing required fields | ||
| incomplete_tool = { | ||
| 'type': 'function', | ||
| 'function': {'name': 'test'}, # missing description and parameters | ||
|
||
| } | ||
| with pytest.raises(ValidationError): | ||
| list(_copy_tools([incomplete_tool])) | ||
Uh oh!
There was an error while loading. Please reload this page.