Skip to content

Commit d7c4f37

Browse files
committed
support Literal type, code refactories
1 parent 8c35e5f commit d7c4f37

File tree

4 files changed

+96
-21
lines changed

4 files changed

+96
-21
lines changed

function_schema/core.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,11 @@
33
import inspect
44

55

6-
class SchemaFormat(str, enum.Enum):
7-
openai = "openai"
8-
claude = "claude"
9-
10-
116
def get_function_schema(
127
func: typing.Annotated[typing.Callable, "The function to get the schema for"],
138
format: typing.Annotated[
14-
typing.Optional[str], SchemaFormat, "The format of the schema to return"
9+
typing.Optional[typing.Literal["openai", "claude"]],
10+
"The format of the schema to return",
1511
] = "openai",
1612
) -> typing.Annotated[dict[str, typing.Any], "The JSON schema for the given function"]:
1713
"""
@@ -66,7 +62,7 @@ def get_function_schema(
6662
}
6763
for name, param in params.items():
6864
param_args = typing.get_args(param.annotation)
69-
is_annotated = len(param_args) > 1
65+
is_annotated = typing.get_origin(param.annotation) is typing.Annotated
7066

7167
enum_ = None
7268
default_value = inspect._empty
@@ -84,15 +80,18 @@ def get_function_schema(
8480
# find enum in param_args tuple
8581
enum_ = next(
8682
(
87-
arg
83+
arg.name
8884
for arg in param_args
8985
if isinstance(arg, type) and issubclass(arg, enum.Enum)
9086
),
91-
None,
87+
# use typing.Literal as enum if no enum found
88+
typing.get_origin(T) is typing.Literal and typing.get_args(T) or None,
9289
)
9390
else:
9491
T = param.annotation
9592
description = f"The {name} parameter"
93+
if typing.get_origin(T) is typing.Literal:
94+
enum_ = typing.get_args(T)
9695

9796
# find default value for param
9897
if param.default is not inspect._empty:
@@ -104,12 +103,16 @@ def get_function_schema(
104103
}
105104

106105
if enum_ is not None:
107-
schema["properties"][name]["enum"] = [t.name for t in enum_]
106+
schema["properties"][name]["enum"] = [t for t in enum_]
108107

109108
if default_value is not inspect._empty:
110109
schema["properties"][name]["default"] = default_value
111110

112-
if not isinstance(None, T) and default_value is inspect._empty:
111+
if (
112+
typing.get_origin(T) is not typing.Literal
113+
and not isinstance(None, T)
114+
and default_value is inspect._empty
115+
):
113116
schema["required"].append(name)
114117

115118
parms_key = "input_schema" if format == "claude" else "parameters"
@@ -128,9 +131,15 @@ def guess_type(
128131
]:
129132
"""Guesses the JSON schema type for the given python type."""
130133

134+
# special case
135+
if T is typing.Any:
136+
return {}
137+
138+
origin = typing.get_origin(T)
139+
131140
# hacking around typing modules, `typing.Union` and `types.UnitonType`
132-
union_types = typing.get_args(T)
133-
if len(union_types) > 1:
141+
if origin is typing.Union:
142+
union_types = [t for t in typing.get_args(T) if t is not type(None)]
134143
_types = []
135144
for union_type in union_types:
136145
_types.append(guess_type(union_type))
@@ -144,6 +153,13 @@ def guess_type(
144153
return _types[0]
145154
return _types
146155

156+
if origin is typing.Literal:
157+
return "string"
158+
elif origin is list or origin is tuple:
159+
return "array"
160+
elif origin is dict:
161+
return "object"
162+
147163
if not isinstance(T, type):
148164
return
149165

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "function-schema"
3-
version = "0.3.0"
3+
version = "0.3.3"
44
requires-python = ">= 3.9"
55
description = "A small utility to generate JSON schemas for python functions."
66
readme = "README.md"

test/test_guess_type.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,22 @@ def test_primitive():
1414
assert guess_type(bool) == "boolean"
1515

1616

17+
def test_typings():
18+
"""Test typing module types"""
19+
assert guess_type(typing.Any) == {}
20+
assert guess_type(typing.List) == "array"
21+
assert guess_type(typing.Dict) == "object"
22+
assert guess_type(typing.Tuple) == "array"
23+
24+
assert guess_type(typing.List[int]) == "array"
25+
assert guess_type(typing.List[str]) == "array"
26+
assert guess_type(typing.List[float]) == "array"
27+
assert guess_type(typing.List[bool]) == "array"
28+
29+
assert guess_type(typing.Dict[str, int]) == "object"
30+
assert guess_type(typing.Dict[str, str]) == "object"
31+
32+
1733
def test_optional():
1834
"""Test optional types"""
1935
assert guess_type(typing.Optional[int]) == "integer"
@@ -51,12 +67,15 @@ def test_union():
5167
]
5268

5369

70+
current_version = packaging.version.parse(platform.python_version())
71+
py_310 = packaging.version.parse("3.10")
72+
73+
74+
@pytest.mark.skipif(
75+
current_version < py_310, reason="Union type is only available in Python 3.10+"
76+
)
5477
def test_union_type():
5578
"""Test union types in Python 3.10+"""
56-
current_version = packaging.version.parse(platform.python_version())
57-
py_310 = packaging.version.parse("3.10")
58-
if current_version < py_310:
59-
pytest.skip("Union type is only available in Python 3.10+")
6079

6180
assert guess_type(int | str) == ["integer", "string"]
6281
assert guess_type(int | float) == "number"
@@ -75,3 +94,9 @@ def test_union_type():
7594
"number",
7695
"boolean",
7796
]
97+
98+
99+
def test_literal_type():
100+
"""Test literal type"""
101+
assert guess_type(typing.Literal["a"]) == "string"
102+
assert guess_type(typing.Literal[1]) == "string"

test/test_schema.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import enum
2-
from typing import Optional, Annotated
2+
from typing import Annotated, Literal
33
from function_schema.core import get_function_schema
44

55

@@ -86,6 +86,7 @@ def func1(
8686

8787
def test_annotated_function_with_enum():
8888
"""Test a function with annotations and enum"""
89+
8990
def func1(
9091
animal: Annotated[
9192
str,
@@ -101,6 +102,39 @@ def func1(
101102
assert (
102103
schema["parameters"]["properties"]["animal"]["type"] == "string"
103104
), "parameter animal should be a string"
105+
assert schema["parameters"]["properties"]["animal"]["enum"] == [
106+
"Cat",
107+
"Dog",
108+
], "parameter animal should have an enum"
109+
110+
111+
def test_literal_type():
112+
"""Test literal type"""
113+
114+
def func1(animal: Annotated[Literal["Cat", "Dog"], "The animal you want to pet"]):
115+
"""My function"""
116+
...
117+
118+
schema = get_function_schema(func1)
119+
print(schema)
104120
assert (
105-
schema["parameters"]["properties"]["animal"]["enum"] == ["Cat", "Dog"]
106-
), "parameter animal should have an enum"
121+
schema["parameters"]["properties"]["animal"]["type"] == "string"
122+
), "parameter animal should be a string"
123+
assert schema["parameters"]["properties"]["animal"]["enum"] == [
124+
"Cat",
125+
"Dog",
126+
], "parameter animal should have an enum"
127+
128+
def func2(animal: Literal["Cat", "Dog"]):
129+
"""My function"""
130+
...
131+
132+
schema = get_function_schema(func2)
133+
assert (
134+
schema["parameters"]["properties"]["animal"]["type"] == "string"
135+
), "parameter animal should be a string"
136+
137+
assert schema["parameters"]["properties"]["animal"]["enum"] == [
138+
"Cat",
139+
"Dog",
140+
], "parameter animal should have an enum"

0 commit comments

Comments
 (0)