11import enum
22import typing
33import inspect
4+ import platform
5+ import packaging .version
46
7+ current_version = packaging .version .parse (platform .python_version ())
8+ py_310 = packaging .version .parse ("3.10" )
59
6- class SchemaFormat (str , enum .Enum ):
7- openai = "openai"
8- claude = "claude"
10+ if current_version >= py_310 :
11+ import types
12+ from types import UnionType
13+ else :
14+ UnionType = typing .Union # type: ignore
915
1016
1117def get_function_schema (
1218 func : typing .Annotated [typing .Callable , "The function to get the schema for" ],
1319 format : typing .Annotated [
14- typing .Optional [str ], SchemaFormat , "The format of the schema to return"
20+ typing .Optional [typing .Literal ["openai" , "claude" ]],
21+ "The format of the schema to return" ,
1522 ] = "openai" ,
1623) -> typing .Annotated [dict [str , typing .Any ], "The JSON schema for the given function" ]:
1724 """
@@ -66,7 +73,7 @@ def get_function_schema(
6673 }
6774 for name , param in params .items ():
6875 param_args = typing .get_args (param .annotation )
69- is_annotated = len ( param_args ) > 1
76+ is_annotated = typing . get_origin ( param . annotation ) is typing . Annotated
7077
7178 enum_ = None
7279 default_value = inspect ._empty
@@ -84,15 +91,18 @@ def get_function_schema(
8491 # find enum in param_args tuple
8592 enum_ = next (
8693 (
87- arg
94+ [ e . name for e in arg ]
8895 for arg in param_args
8996 if isinstance (arg , type ) and issubclass (arg , enum .Enum )
9097 ),
91- None ,
98+ # use typing.Literal as enum if no enum found
99+ typing .get_origin (T ) is typing .Literal and typing .get_args (T ) or None ,
92100 )
93101 else :
94102 T = param .annotation
95103 description = f"The { name } parameter"
104+ if typing .get_origin (T ) is typing .Literal :
105+ enum_ = typing .get_args (T )
96106
97107 # find default value for param
98108 if param .default is not inspect ._empty :
@@ -104,12 +114,16 @@ def get_function_schema(
104114 }
105115
106116 if enum_ is not None :
107- schema ["properties" ][name ]["enum" ] = [t . name for t in enum_ ]
117+ schema ["properties" ][name ]["enum" ] = [t for t in enum_ ]
108118
109119 if default_value is not inspect ._empty :
110120 schema ["properties" ][name ]["default" ] = default_value
111121
112- if not isinstance (None , T ) and default_value is inspect ._empty :
122+ if (
123+ typing .get_origin (T ) is not typing .Literal
124+ and not isinstance (None , T )
125+ and default_value is inspect ._empty
126+ ):
113127 schema ["required" ].append (name )
114128
115129 parms_key = "input_schema" if format == "claude" else "parameters"
@@ -128,9 +142,15 @@ def guess_type(
128142]:
129143 """Guesses the JSON schema type for the given python type."""
130144
145+ # special case
146+ if T is typing .Any :
147+ return {}
148+
149+ origin = typing .get_origin (T )
150+
131151 # hacking around typing modules, `typing.Union` and `types.UnitonType`
132- union_types = typing .get_args ( T )
133- if len ( union_types ) > 1 :
152+ if origin is typing .Union or origin is UnionType :
153+ union_types = [ t for t in typing . get_args ( T ) if t is not type ( None )]
134154 _types = []
135155 for union_type in union_types :
136156 _types .append (guess_type (union_type ))
@@ -144,6 +164,13 @@ def guess_type(
144164 return _types [0 ]
145165 return _types
146166
167+ if origin is typing .Literal :
168+ return "string"
169+ elif origin is list or origin is tuple :
170+ return "array"
171+ elif origin is dict :
172+ return "object"
173+
147174 if not isinstance (T , type ):
148175 return
149176
0 commit comments