Skip to content

Commit cf18fa3

Browse files
committed
fix bugs
1 parent 33b03a1 commit cf18fa3

File tree

2 files changed

+59
-42
lines changed

2 files changed

+59
-42
lines changed

function_schema/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from importlib.util import module_from_spec, spec_from_file_location
33
import inspect
44
import json
5-
from core import get_function_schema
5+
from .core import get_function_schema
66

77

88
def print_usage():

function_schema/core.py

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import typing
33
import inspect
44

5+
56
def get_function_schema(
67
func: typing.Annotated[typing.Callable, "The function to get the schema for"]
78
) -> typing.Annotated[dict[str, typing.Any], "The JSON schema for the given function"]:
@@ -26,25 +27,25 @@ def get_function_schema(
2627
... ) -> str:
2728
... \"\"\"Returns the weather for the given city.\"\"\"
2829
... return f"Hello {name}, you are {age} years old."
29-
>>> get_function_schema(get_weather)
30+
>>> get_function_schema(get_weather) # doctest: +SKIP
3031
{
31-
"name": "get_weather",
32-
"description": "Returns the weather for the given city.",
33-
"parameters": {
34-
"type": "object",
35-
"properties": {
36-
"city": {
37-
"type": "string",
38-
"description": "The city to get the weather for"
32+
'name': 'get_weather',
33+
'description': 'Returns the weather for the given city.',
34+
'parameters': {
35+
'type': 'object',
36+
'properties': {
37+
'city': {
38+
'type': 'string',
39+
'description': 'The city to get the weather for'
3940
},
40-
"unit": {
41-
"type": "string",
42-
"description": "The unit to return the temperature in",
43-
"enum": ["celcius", "fahrenheit"],
44-
"default": "celcius"
41+
'unit': {
42+
'type': 'string',
43+
'description': 'The unit to return the temperature in',
44+
'enum': ['celcius', 'fahrenheit'],
45+
'default': 'celcius'
4546
}
4647
},
47-
"required": ["city"]
48+
'required': ['city']
4849
}
4950
}
5051
"""
@@ -64,7 +65,7 @@ def get_function_schema(
6465

6566
if is_annotated:
6667
# first arg is type
67-
(T, _) = param_args
68+
(T, *_) = param_args
6869

6970
# find description in param_args tuple
7071
description = next(
@@ -74,7 +75,12 @@ def get_function_schema(
7475

7576
# find enum in param_args tuple
7677
enum_ = next(
77-
(arg for arg in param_args if isinstance(arg, enum.Enum)), None
78+
(
79+
arg
80+
for arg in param_args
81+
if isinstance(arg, type) and issubclass(arg, enum.Enum)
82+
),
83+
None,
7884
)
7985
else:
8086
T = param.annotation
@@ -90,15 +96,15 @@ def get_function_schema(
9096
}
9197

9298
if enum_ is not None:
93-
schema["properties"][name]["enum"] = enum_.values
99+
schema["properties"][name]["enum"] = [t.name for t in enum_]
94100

95101
if default_value is not inspect._empty:
96102
schema["properties"][name]["default"] = default_value
97103

98-
if not isinstance(None, T):
104+
if not isinstance(None, T) and default_value is inspect._empty:
99105
schema["required"].append(name)
100106
return {
101-
"name": func.__qualname__,
107+
"name": func.__name__,
102108
"description": inspect.getdoc(func),
103109
"parameters": schema,
104110
}
@@ -110,27 +116,38 @@ def guess_type(
110116
typing.Union[str, list[str]], "str | list of str that representing JSON schema type"
111117
]:
112118
"""Guesses the JSON schema type for the given python type."""
113-
_types = []
114119

115120
# hacking around typing modules, `typing.Union` and `types.UnitonType`
116-
if isinstance(1, T):
117-
_types.append("integer")
118-
elif isinstance(1.1, T):
119-
_types.append("number")
120-
121-
if isinstance("", T):
122-
_types.append("string")
123-
if not isinstance(1, T) and isinstance(True, T):
124-
_types.append("boolean")
125-
if isinstance([], T):
126-
_types.append("array")
127-
if isinstance({}, T):
128-
return "object"
129-
130-
if len(_types) == 0:
121+
union_types = typing.get_args(T)
122+
if len(union_types) > 1:
123+
_types = []
124+
for union_type in union_types:
125+
_types.append(guess_type(union_type))
126+
_types = [t for t in _types if t is not None] # exclude None
127+
128+
# number contains integer in JSON schema
129+
if 'number' in _types and 'integer' in _types:
130+
_types.remove('integer')
131+
132+
if len(_types) == 1:
133+
return _types[0]
134+
return _types
135+
136+
if not isinstance(T, type):
137+
return
138+
139+
if T.__name__ == 'NoneType':
140+
return
141+
142+
if issubclass(T, str):
143+
return "string"
144+
if issubclass(T, bool):
145+
return "boolean"
146+
if issubclass(T, float):
147+
return "number"
148+
elif issubclass(T, int):
149+
return "integer"
150+
if T.__name__ == "list":
151+
return "array"
152+
if T.__name__ == "dict":
131153
return "object"
132-
133-
if len(_types) == 1:
134-
return _types[0]
135-
136-
return _types

0 commit comments

Comments
 (0)