Skip to content

Commit 6ccdd51

Browse files
authored
Merge pull request #3 from comfuture/literal-type
Support Literal type and code refactorings
2 parents 8c35e5f + e070fca commit 6ccdd51

File tree

4 files changed

+105
-19
lines changed

4 files changed

+105
-19
lines changed

function_schema/core.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
import enum
22
import typing
33
import inspect
4+
import platform
5+
import packaging.version
46

7+
current_version = packaging.version.parse(platform.python_version())
8+
py_310 = packaging.version.parse("3.10")
59

6-
class SchemaFormat(str, enum.Enum):
7-
openai = "openai"
8-
claude = "claude"
10+
if current_version >= py_310:
11+
import types
12+
from types import UnionType
13+
else:
14+
UnionType = typing.Union # type: ignore
915

1016

1117
def get_function_schema(
1218
func: typing.Annotated[typing.Callable, "The function to get the schema for"],
1319
format: typing.Annotated[
14-
typing.Optional[str], SchemaFormat, "The format of the schema to return"
20+
typing.Optional[typing.Literal["openai", "claude"]],
21+
"The format of the schema to return",
1522
] = "openai",
1623
) -> typing.Annotated[dict[str, typing.Any], "The JSON schema for the given function"]:
1724
"""
@@ -66,7 +73,7 @@ def get_function_schema(
6673
}
6774
for name, param in params.items():
6875
param_args = typing.get_args(param.annotation)
69-
is_annotated = len(param_args) > 1
76+
is_annotated = typing.get_origin(param.annotation) is typing.Annotated
7077

7178
enum_ = None
7279
default_value = inspect._empty
@@ -84,15 +91,18 @@ def get_function_schema(
8491
# find enum in param_args tuple
8592
enum_ = next(
8693
(
87-
arg
94+
[e.name for e in arg]
8895
for arg in param_args
8996
if isinstance(arg, type) and issubclass(arg, enum.Enum)
9097
),
91-
None,
98+
# use typing.Literal as enum if no enum found
99+
typing.get_origin(T) is typing.Literal and typing.get_args(T) or None,
92100
)
93101
else:
94102
T = param.annotation
95103
description = f"The {name} parameter"
104+
if typing.get_origin(T) is typing.Literal:
105+
enum_ = typing.get_args(T)
96106

97107
# find default value for param
98108
if param.default is not inspect._empty:
@@ -104,12 +114,16 @@ def get_function_schema(
104114
}
105115

106116
if enum_ is not None:
107-
schema["properties"][name]["enum"] = [t.name for t in enum_]
117+
schema["properties"][name]["enum"] = [t for t in enum_]
108118

109119
if default_value is not inspect._empty:
110120
schema["properties"][name]["default"] = default_value
111121

112-
if not isinstance(None, T) and default_value is inspect._empty:
122+
if (
123+
typing.get_origin(T) is not typing.Literal
124+
and not isinstance(None, T)
125+
and default_value is inspect._empty
126+
):
113127
schema["required"].append(name)
114128

115129
parms_key = "input_schema" if format == "claude" else "parameters"
@@ -128,9 +142,15 @@ def guess_type(
128142
]:
129143
"""Guesses the JSON schema type for the given python type."""
130144

145+
# special case
146+
if T is typing.Any:
147+
return {}
148+
149+
origin = typing.get_origin(T)
150+
131151
# hacking around typing modules, `typing.Union` and `types.UnitonType`
132-
union_types = typing.get_args(T)
133-
if len(union_types) > 1:
152+
if origin is typing.Union or origin is UnionType:
153+
union_types = [t for t in typing.get_args(T) if t is not type(None)]
134154
_types = []
135155
for union_type in union_types:
136156
_types.append(guess_type(union_type))
@@ -144,6 +164,13 @@ def guess_type(
144164
return _types[0]
145165
return _types
146166

167+
if origin is typing.Literal:
168+
return "string"
169+
elif origin is list or origin is tuple:
170+
return "array"
171+
elif origin is dict:
172+
return "object"
173+
147174
if not isinstance(T, type):
148175
return
149176

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "function-schema"
3-
version = "0.3.0"
3+
version = "0.3.3"
44
requires-python = ">= 3.9"
55
description = "A small utility to generate JSON schemas for python functions."
66
readme = "README.md"

test/test_guess_type.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,22 @@ def test_primitive():
1414
assert guess_type(bool) == "boolean"
1515

1616

17+
def test_typings():
18+
"""Test typing module types"""
19+
assert guess_type(typing.Any) == {}
20+
assert guess_type(typing.List) == "array"
21+
assert guess_type(typing.Dict) == "object"
22+
assert guess_type(typing.Tuple) == "array"
23+
24+
assert guess_type(typing.List[int]) == "array"
25+
assert guess_type(typing.List[str]) == "array"
26+
assert guess_type(typing.List[float]) == "array"
27+
assert guess_type(typing.List[bool]) == "array"
28+
29+
assert guess_type(typing.Dict[str, int]) == "object"
30+
assert guess_type(typing.Dict[str, str]) == "object"
31+
32+
1733
def test_optional():
1834
"""Test optional types"""
1935
assert guess_type(typing.Optional[int]) == "integer"
@@ -51,12 +67,15 @@ def test_union():
5167
]
5268

5369

70+
current_version = packaging.version.parse(platform.python_version())
71+
py_310 = packaging.version.parse("3.10")
72+
73+
74+
@pytest.mark.skipif(
75+
current_version < py_310, reason="Union type is only available in Python 3.10+"
76+
)
5477
def test_union_type():
5578
"""Test union types in Python 3.10+"""
56-
current_version = packaging.version.parse(platform.python_version())
57-
py_310 = packaging.version.parse("3.10")
58-
if current_version < py_310:
59-
pytest.skip("Union type is only available in Python 3.10+")
6079

6180
assert guess_type(int | str) == ["integer", "string"]
6281
assert guess_type(int | float) == "number"
@@ -75,3 +94,9 @@ def test_union_type():
7594
"number",
7695
"boolean",
7796
]
97+
98+
99+
def test_literal_type():
100+
"""Test literal type"""
101+
assert guess_type(typing.Literal["a"]) == "string"
102+
assert guess_type(typing.Literal[1]) == "string"

test/test_schema.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import enum
2-
from typing import Optional, Annotated
2+
from typing import Annotated, Literal
33
from function_schema.core import get_function_schema
44

55

@@ -86,6 +86,7 @@ def func1(
8686

8787
def test_annotated_function_with_enum():
8888
"""Test a function with annotations and enum"""
89+
8990
def func1(
9091
animal: Annotated[
9192
str,
@@ -101,6 +102,39 @@ def func1(
101102
assert (
102103
schema["parameters"]["properties"]["animal"]["type"] == "string"
103104
), "parameter animal should be a string"
105+
assert schema["parameters"]["properties"]["animal"]["enum"] == [
106+
"Cat",
107+
"Dog",
108+
], "parameter animal should have an enum"
109+
110+
111+
def test_literal_type():
112+
"""Test literal type"""
113+
114+
def func1(animal: Annotated[Literal["Cat", "Dog"], "The animal you want to pet"]):
115+
"""My function"""
116+
...
117+
118+
schema = get_function_schema(func1)
119+
print(schema)
104120
assert (
105-
schema["parameters"]["properties"]["animal"]["enum"] == ["Cat", "Dog"]
106-
), "parameter animal should have an enum"
121+
schema["parameters"]["properties"]["animal"]["type"] == "string"
122+
), "parameter animal should be a string"
123+
assert schema["parameters"]["properties"]["animal"]["enum"] == [
124+
"Cat",
125+
"Dog",
126+
], "parameter animal should have an enum"
127+
128+
def func2(animal: Literal["Cat", "Dog"]):
129+
"""My function"""
130+
...
131+
132+
schema = get_function_schema(func2)
133+
assert (
134+
schema["parameters"]["properties"]["animal"]["type"] == "string"
135+
), "parameter animal should be a string"
136+
137+
assert schema["parameters"]["properties"]["animal"]["enum"] == [
138+
"Cat",
139+
"Dog",
140+
], "parameter animal should have an enum"

0 commit comments

Comments
 (0)