Skip to content

Commit 6a50a38

Browse files
authored
Merge pull request #4 from comfuture/literal-type
Fix enum handling in get_function_schema
2 parents 6ccdd51 + 65ff7d0 commit 6a50a38

File tree

4 files changed

+79
-20
lines changed

4 files changed

+79
-20
lines changed

function_schema/core.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
py_310 = packaging.version.parse("3.10")
99

1010
if current_version >= py_310:
11-
import types
1211
from types import UnionType
1312
else:
1413
UnionType = typing.Union # type: ignore
@@ -114,7 +113,7 @@ def get_function_schema(
114113
}
115114

116115
if enum_ is not None:
117-
schema["properties"][name]["enum"] = [t for t in enum_]
116+
schema["properties"][name]["enum"] = [t for t in enum_ if t is not None]
118117

119118
if default_value is not inspect._empty:
120119
schema["properties"][name]["default"] = default_value
@@ -126,8 +125,14 @@ def get_function_schema(
126125
):
127126
schema["required"].append(name)
128127

128+
if typing.get_origin(T) is typing.Literal:
129+
if all(typing.get_args(T)):
130+
schema["required"].append(name)
131+
129132
parms_key = "input_schema" if format == "claude" else "parameters"
130133

134+
schema["required"] = list(set(schema["required"]))
135+
131136
return {
132137
"name": func.__name__,
133138
"description": inspect.getdoc(func),
@@ -148,13 +153,17 @@ def guess_type(
148153

149154
origin = typing.get_origin(T)
150155

156+
if origin is typing.Annotated:
157+
return guess_type(typing.get_args(T)[0])
158+
151159
# hacking around typing modules, `typing.Union` and `types.UnitonType`
152-
if origin is typing.Union or origin is UnionType:
160+
if origin in [typing.Union, UnionType]:
153161
union_types = [t for t in typing.get_args(T) if t is not type(None)]
154-
_types = []
155-
for union_type in union_types:
156-
_types.append(guess_type(union_type))
157-
_types = [t for t in _types if t is not None] # exclude None
162+
_types = [
163+
guess_type(union_type)
164+
for union_type in union_types
165+
if guess_type(union_type) is not None
166+
]
158167

159168
# number contains integer in JSON schema
160169
if "number" in _types and "integer" in _types:
@@ -165,7 +174,8 @@ def guess_type(
165174
return _types
166175

167176
if origin is typing.Literal:
168-
return "string"
177+
type_args = typing.Union[tuple(type(arg) for arg in typing.get_args(T))]
178+
return guess_type(type_args)
169179
elif origin is list or origin is tuple:
170180
return "array"
171181
elif origin is dict:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "function-schema"
3-
version = "0.3.3"
3+
version = "0.3.4"
44
requires-python = ">= 3.9"
55
description = "A small utility to generate JSON schemas for python functions."
66
readme = "README.md"

test/test_guess_type.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,21 @@ def test_union_type():
9898

9999
def test_literal_type():
100100
"""Test literal type"""
101-
assert guess_type(typing.Literal["a"]) == "string"
102-
assert guess_type(typing.Literal[1]) == "string"
101+
assert guess_type(typing.Literal["a"]) == "string", "should be string"
102+
assert guess_type(typing.Literal[1]) == "integer", "should be integer"
103+
assert guess_type(typing.Literal[1.0]) == "number", "should be number"
104+
105+
assert set(guess_type(typing.Literal["a", 1, None])) == {
106+
"string",
107+
"integer",
108+
}, "should be string or integer, omit None"
109+
110+
assert set(guess_type(typing.Literal["a", 1])) == {"string", "integer"}
111+
assert set(guess_type(typing.Literal["a", 1.0])) == {"string", "number"}
112+
assert set(guess_type(typing.Literal["a", 1.1])) == {"string", "number"}
113+
assert set(guess_type(typing.Literal["a", 1, 1.0])) == {
114+
"string",
115+
"number",
116+
}, "should omit integer if number is present"
117+
118+
assert set(guess_type(typing.Literal["a", 1, 1.0, None])) == {"string", "number"}

test/test_schema.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def func1(a: int, b: str, c: float = 1.0):
4747
assert (
4848
schema["parameters"]["properties"]["c"]["default"] == 1.0
4949
), "c should have a default value of 1.0"
50-
assert schema["parameters"]["required"] == [
51-
"a",
52-
"b",
53-
], "parameters with no default value should be required"
50+
assert (
51+
"a" in schema["parameters"]["required"]
52+
and "b" in schema["parameters"]["required"]
53+
), "parameters with no default value should be required"
5454

5555

5656
def test_annotated_function():
@@ -78,10 +78,10 @@ def func1(
7878
schema["parameters"]["properties"]["b"]["description"] == "A string parameter"
7979
), "parameter b should have a description"
8080

81-
assert schema["parameters"]["required"] == [
82-
"a",
83-
"b",
84-
], "parameters with no default value should be required"
81+
assert (
82+
"a" in schema["parameters"]["required"]
83+
and "b" in schema["parameters"]["required"]
84+
), "parameters with no default value should be required"
8585

8686

8787
def test_annotated_function_with_enum():
@@ -116,7 +116,6 @@ def func1(animal: Annotated[Literal["Cat", "Dog"], "The animal you want to pet"]
116116
...
117117

118118
schema = get_function_schema(func1)
119-
print(schema)
120119
assert (
121120
schema["parameters"]["properties"]["animal"]["type"] == "string"
122121
), "parameter animal should be a string"
@@ -138,3 +137,37 @@ def func2(animal: Literal["Cat", "Dog"]):
138137
"Cat",
139138
"Dog",
140139
], "parameter animal should have an enum"
140+
141+
assert "animal" in schema["parameters"]["required"], "animal is required"
142+
143+
def func3(animal: Literal["Cat", "Dog", 1]):
144+
"""My function"""
145+
...
146+
147+
schema = get_function_schema(func3)
148+
assert schema["parameters"]["properties"]["animal"]["type"] == [
149+
"string",
150+
"integer",
151+
], "parameter animal should be a string"
152+
153+
def func4(animal: Literal["Cat", "Dog", 1, None]):
154+
"""My function"""
155+
...
156+
157+
schema = get_function_schema(func4)
158+
159+
assert schema["parameters"]["properties"]["animal"]["type"] == [
160+
"string",
161+
"integer",
162+
], "parameter animal should be a string"
163+
164+
assert schema["parameters"]["properties"]["animal"]["enum"] == [
165+
"Cat",
166+
"Dog",
167+
1,
168+
], "parameter animal should have an enum"
169+
170+
assert (
171+
schema["parameters"].get("required") is None
172+
or "animal" not in schema["parameters"]["required"]
173+
), "animal should not be required"

0 commit comments

Comments
 (0)