Skip to content

Commit 8c7aefa

Browse files
committed
Enhance Doc in Annotation and implement its tests
1 parent 9f7659a commit 8c7aefa

File tree

2 files changed

+73
-44
lines changed

2 files changed

+73
-44
lines changed

function_schema/core.py

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import enum
2-
import typing
32
import inspect
43
import platform
54
import packaging.version
5+
from typing import Annotated, Optional, Union, Callable, Literal, Any, get_args, get_origin
66

77
current_version = packaging.version.parse(platform.python_version())
88
py_310 = packaging.version.parse("3.10")
99

1010
if current_version >= py_310:
1111
from types import UnionType
1212
else:
13-
UnionType = typing.Union # type: ignore
13+
UnionType = Union # type: ignore
1414

1515
try:
1616
from typing import Doc
@@ -22,29 +22,23 @@ class Doc:
2222
def __init__(self, documentation: str, /):
2323
self.documentation = documentation
2424

25-
__all__ = ("get_function_schema", "guess_type", "Doc")
25+
__all__ = ("get_function_schema", "guess_type", "Doc", "Annotated")
2626

27-
def is_doc_meta(obj):
27+
28+
def is_doc_meta(obj: Annotated[Any, Doc("The object to be checked.")]) -> Annotated[bool, Doc("True if the object is a documentation object, False otherwise.")]:
2829
"""
2930
Check if the given object is a documentation object.
30-
Parameters:
31-
obj (object): The object to be checked.
32-
Returns:
33-
bool: True if the object is a documentation object, False otherwise.
3431
3532
Example:
3633
>>> is_doc_meta(Doc("This is a documentation object"))
3734
True
3835
"""
3936
return getattr(obj, '__class__') == Doc and hasattr(obj, 'documentation')
4037

41-
def unwrap_doc(obj: typing.Union[Doc, str]):
38+
39+
def unwrap_doc(obj: Annotated[Union[Doc, str], Doc("The object to get the documentation string from.")]) -> Annotated[str, Doc("The documentation string.")]:
4240
"""
4341
Get the documentation string from the given object.
44-
Parameters:
45-
obj (Doc | str): The object to get the documentation string from.
46-
Returns:
47-
str: The documentation string.
4842
4943
Example:
5044
>>> unwrap_doc(Doc("This is a documentation object"))
@@ -58,12 +52,12 @@ def unwrap_doc(obj: typing.Union[Doc, str]):
5852

5953

6054
def get_function_schema(
61-
func: typing.Annotated[typing.Callable, "The function to get the schema for"],
62-
format: typing.Annotated[
63-
typing.Optional[typing.Literal["openai", "claude"]],
64-
"The format of the schema to return",
55+
func: Annotated[Callable, Doc("The function to get the schema for")],
56+
format: Annotated[
57+
Optional[Literal["openai", "claude"]],
58+
Doc("The format of the schema to return"),
6559
] = "openai",
66-
) -> typing.Annotated[dict[str, typing.Any], "The JSON schema for the given function"]:
60+
) -> Annotated[dict[str, Any], Doc("The JSON schema for the given function")]:
6761
"""
6862
Returns a JSON schema for the given function.
6963
@@ -76,10 +70,10 @@ def get_function_schema(
7670
>>> from typing import Annotated, Optional
7771
>>> import enum
7872
>>> def get_weather(
79-
... city: Annotated[str, "The city to get the weather for"],
73+
... city: Annotated[str, Doc("The city to get the weather for")],
8074
... unit: Annotated[
8175
... Optional[str],
82-
... "The unit to return the temperature in",
76+
... Doc("The unit to return the temperature in"),
8377
... enum.Enum("Unit", "celcius fahrenheit")
8478
... ] = "celcius",
8579
... ) -> str:
@@ -115,8 +109,8 @@ def get_function_schema(
115109
"required": [],
116110
}
117111
for name, param in params.items():
118-
param_args = typing.get_args(param.annotation)
119-
is_annotated = typing.get_origin(param.annotation) is typing.Annotated
112+
param_args = get_args(param.annotation)
113+
is_annotated = get_origin(param.annotation) is Annotated
120114

121115
enum_ = None
122116
default_value = inspect._empty
@@ -126,10 +120,17 @@ def get_function_schema(
126120
(T, *_) = param_args
127121

128122
# find description in param_args tuple
129-
description = next(
130-
(unwrap_doc(arg) for arg in param_args if isinstance(arg, (Doc, str))),
131-
f"The {name} parameter",
132-
)
123+
try:
124+
description = next(
125+
unwrap_doc(arg)
126+
for arg in param_args if isinstance(arg, Doc)
127+
)
128+
except StopIteration:
129+
try:
130+
description = next(
131+
arg for arg in param_args if isinstance(arg, str))
132+
except StopIteration:
133+
description = "The {name} parameter"
133134

134135
# find enum in param_args tuple
135136
enum_ = next(
@@ -139,13 +140,13 @@ def get_function_schema(
139140
if isinstance(arg, type) and issubclass(arg, enum.Enum)
140141
),
141142
# use typing.Literal as enum if no enum found
142-
typing.get_origin(T) is typing.Literal and typing.get_args(T) or None,
143+
get_origin(T) is Literal and get_args(T) or None,
143144
)
144145
else:
145146
T = param.annotation
146147
description = f"The {name} parameter"
147-
if typing.get_origin(T) is typing.Literal:
148-
enum_ = typing.get_args(T)
148+
if get_origin(T) is Literal:
149+
enum_ = get_args(T)
149150

150151
# find default value for param
151152
if param.default is not inspect._empty:
@@ -157,20 +158,21 @@ def get_function_schema(
157158
}
158159

159160
if enum_ is not None:
160-
schema["properties"][name]["enum"] = [t for t in enum_ if t is not None]
161+
schema["properties"][name]["enum"] = [
162+
t for t in enum_ if t is not None]
161163

162164
if default_value is not inspect._empty:
163165
schema["properties"][name]["default"] = default_value
164166

165167
if (
166-
typing.get_origin(T) is not typing.Literal
168+
get_origin(T) is not Literal
167169
and not isinstance(None, T)
168170
and default_value is inspect._empty
169171
):
170172
schema["required"].append(name)
171173

172-
if typing.get_origin(T) is typing.Literal:
173-
if all(typing.get_args(T)):
174+
if get_origin(T) is Literal:
175+
if all(get_args(T)):
174176
schema["required"].append(name)
175177

176178
parms_key = "input_schema" if format == "claude" else "parameters"
@@ -185,24 +187,25 @@ def get_function_schema(
185187

186188

187189
def guess_type(
188-
T: typing.Annotated[type, "The type to guess the JSON schema type for"],
189-
) -> typing.Annotated[
190-
typing.Union[str, list[str]], "str | list of str that representing JSON schema type"
190+
T: Annotated[type, Doc("The type to guess the JSON schema type for")],
191+
) -> Annotated[
192+
Union[str, list[str]], Doc(
193+
"str | list of str that representing JSON schema type")
191194
]:
192195
"""Guesses the JSON schema type for the given python type."""
193196

194197
# special case
195-
if T is typing.Any:
198+
if T is Any:
196199
return {}
197200

198-
origin = typing.get_origin(T)
201+
origin = get_origin(T)
199202

200-
if origin is typing.Annotated:
201-
return guess_type(typing.get_args(T)[0])
203+
if origin is Annotated:
204+
return guess_type(get_args(T)[0])
202205

203206
# hacking around typing modules, `typing.Union` and `types.UnitonType`
204-
if origin in [typing.Union, UnionType]:
205-
union_types = [t for t in typing.get_args(T) if t is not type(None)]
207+
if origin in [Union, UnionType]:
208+
union_types = [t for t in get_args(T) if t is not type(None)]
206209
_types = [
207210
guess_type(union_type)
208211
for union_type in union_types
@@ -217,8 +220,8 @@ def guess_type(
217220
return _types[0]
218221
return _types
219222

220-
if origin is typing.Literal:
221-
type_args = typing.Union[tuple(type(arg) for arg in typing.get_args(T))]
223+
if origin is Literal:
224+
type_args = Union[tuple(type(arg) for arg in get_args(T))]
222225
return guess_type(type_args)
223226
elif origin is list or origin is tuple:
224227
return "array"

test/test_pep_0727_doc.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,29 @@ def func1(a: Annotated[str, Enum("Candidates", "a b c"), Doc("A string parameter
3131
assert schema["parameters"]["properties"]["a"]["description"] == "A string parameter", "parameter a should have a description"
3232
assert schema["parameters"]["properties"]["a"]["enum"] == [
3333
"a", "b", "c"], "parameter a should have enum values"
34+
35+
36+
def test_multiple_docs_in_annotation():
37+
"""Test a function with annotations with multiple Doc"""
38+
def func1(a: Annotated[int, Doc("An integer parameter"), Doc("A number")]):
39+
"""My function"""
40+
...
41+
42+
schema = get_function_schema(func1)
43+
assert schema["name"] == "func1", "Function name should be func1"
44+
assert schema["description"] == "My function", "Function description should be there"
45+
assert schema["parameters"]["properties"]["a"]["type"] == "number", "parameter a should be an integer"
46+
assert schema["parameters"]["properties"]["a"]["description"] == "An integer parameter", "first description should be used"
47+
48+
49+
def test_mixed_docs_in_annotation():
50+
"""Test a function with annotations with mixed Doc and strings"""
51+
def func1(a: Annotated[int, "An integer parameter", Doc("A number")]):
52+
"""My function"""
53+
...
54+
55+
schema = get_function_schema(func1)
56+
assert schema["name"] == "func1", "Function name should be func1"
57+
assert schema["description"] == "My function", "Function description should be there"
58+
assert schema["parameters"]["properties"]["a"]["type"] == "number", "parameter a should be an integer"
59+
assert schema["parameters"]["properties"]["a"]["description"] == "A number", "`Doc` should be used rather than string"

0 commit comments

Comments
 (0)