11import enum
2- import typing
32import inspect
43import platform
54import packaging .version
5+ from typing import Annotated , Optional , Union , Callable , Literal , Any , get_args , get_origin
66
77current_version = packaging .version .parse (platform .python_version ())
88py_310 = packaging .version .parse ("3.10" )
99
1010if current_version >= py_310 :
1111 from types import UnionType
1212else :
13- UnionType = typing . Union # type: ignore
13+ UnionType = Union # type: ignore
1414
1515try :
1616 from typing import Doc
@@ -22,29 +22,23 @@ class Doc:
2222 def __init__ (self , documentation : str , / ):
2323 self .documentation = documentation
2424
25- __all__ = ("get_function_schema" , "guess_type" , "Doc" )
25+ __all__ = ("get_function_schema" , "guess_type" , "Doc" , "Annotated" )
2626
27- def is_doc_meta (obj ):
27+
28+ def is_doc_meta (obj : Annotated [Any , Doc ("The object to be checked." )]) -> Annotated [bool , Doc ("True if the object is a documentation object, False otherwise." )]:
2829 """
2930 Check if the given object is a documentation object.
30- Parameters:
31- obj (object): The object to be checked.
32- Returns:
33- bool: True if the object is a documentation object, False otherwise.
3431
3532 Example:
3633 >>> is_doc_meta(Doc("This is a documentation object"))
3734 True
3835 """
3936 return getattr (obj , '__class__' ) == Doc and hasattr (obj , 'documentation' )
4037
41- def unwrap_doc (obj : typing .Union [Doc , str ]):
38+
39+ def unwrap_doc (obj : Annotated [Union [Doc , str ], Doc ("The object to get the documentation string from." )]) -> Annotated [str , Doc ("The documentation string." )]:
4240 """
4341 Get the documentation string from the given object.
44- Parameters:
45- obj (Doc | str): The object to get the documentation string from.
46- Returns:
47- str: The documentation string.
4842
4943 Example:
5044 >>> unwrap_doc(Doc("This is a documentation object"))
@@ -58,12 +52,12 @@ def unwrap_doc(obj: typing.Union[Doc, str]):
5852
5953
6054def get_function_schema (
61- func : typing . Annotated [typing . Callable , "The function to get the schema for" ],
62- format : typing . Annotated [
63- typing . Optional [typing . Literal ["openai" , "claude" ]],
64- "The format of the schema to return" ,
55+ func : Annotated [Callable , Doc ( "The function to get the schema for" ) ],
56+ format : Annotated [
57+ Optional [Literal ["openai" , "claude" ]],
58+ Doc ( "The format of the schema to return" ) ,
6559 ] = "openai" ,
66- ) -> typing . Annotated [dict [str , typing . Any ], "The JSON schema for the given function" ]:
60+ ) -> Annotated [dict [str , Any ], Doc ( "The JSON schema for the given function" ) ]:
6761 """
6862 Returns a JSON schema for the given function.
6963
@@ -76,10 +70,10 @@ def get_function_schema(
7670 >>> from typing import Annotated, Optional
7771 >>> import enum
7872 >>> def get_weather(
79- ... city: Annotated[str, "The city to get the weather for"],
73+ ... city: Annotated[str, Doc( "The city to get the weather for") ],
8074 ... unit: Annotated[
8175 ... Optional[str],
82- ... "The unit to return the temperature in",
76+ ... Doc( "The unit to return the temperature in") ,
8377 ... enum.Enum("Unit", "celcius fahrenheit")
8478 ... ] = "celcius",
8579 ... ) -> str:
@@ -115,8 +109,8 @@ def get_function_schema(
115109 "required" : [],
116110 }
117111 for name , param in params .items ():
118- param_args = typing . get_args (param .annotation )
119- is_annotated = typing . get_origin (param .annotation ) is typing . Annotated
112+ param_args = get_args (param .annotation )
113+ is_annotated = get_origin (param .annotation ) is Annotated
120114
121115 enum_ = None
122116 default_value = inspect ._empty
@@ -126,10 +120,17 @@ def get_function_schema(
126120 (T , * _ ) = param_args
127121
128122 # find description in param_args tuple
129- description = next (
130- (unwrap_doc (arg ) for arg in param_args if isinstance (arg , (Doc , str ))),
131- f"The { name } parameter" ,
132- )
123+ try :
124+ description = next (
125+ unwrap_doc (arg )
126+ for arg in param_args if isinstance (arg , Doc )
127+ )
128+ except StopIteration :
129+ try :
130+ description = next (
131+ arg for arg in param_args if isinstance (arg , str ))
132+ except StopIteration :
133+ description = "The {name} parameter"
133134
134135 # find enum in param_args tuple
135136 enum_ = next (
@@ -139,13 +140,13 @@ def get_function_schema(
139140 if isinstance (arg , type ) and issubclass (arg , enum .Enum )
140141 ),
141142 # use typing.Literal as enum if no enum found
142- typing . get_origin (T ) is typing . Literal and typing . get_args (T ) or None ,
143+ get_origin (T ) is Literal and get_args (T ) or None ,
143144 )
144145 else :
145146 T = param .annotation
146147 description = f"The { name } parameter"
147- if typing . get_origin (T ) is typing . Literal :
148- enum_ = typing . get_args (T )
148+ if get_origin (T ) is Literal :
149+ enum_ = get_args (T )
149150
150151 # find default value for param
151152 if param .default is not inspect ._empty :
@@ -157,20 +158,21 @@ def get_function_schema(
157158 }
158159
159160 if enum_ is not None :
160- schema ["properties" ][name ]["enum" ] = [t for t in enum_ if t is not None ]
161+ schema ["properties" ][name ]["enum" ] = [
162+ t for t in enum_ if t is not None ]
161163
162164 if default_value is not inspect ._empty :
163165 schema ["properties" ][name ]["default" ] = default_value
164166
165167 if (
166- typing . get_origin (T ) is not typing . Literal
168+ get_origin (T ) is not Literal
167169 and not isinstance (None , T )
168170 and default_value is inspect ._empty
169171 ):
170172 schema ["required" ].append (name )
171173
172- if typing . get_origin (T ) is typing . Literal :
173- if all (typing . get_args (T )):
174+ if get_origin (T ) is Literal :
175+ if all (get_args (T )):
174176 schema ["required" ].append (name )
175177
176178 parms_key = "input_schema" if format == "claude" else "parameters"
@@ -185,24 +187,25 @@ def get_function_schema(
185187
186188
187189def guess_type (
188- T : typing .Annotated [type , "The type to guess the JSON schema type for" ],
189- ) -> typing .Annotated [
190- typing .Union [str , list [str ]], "str | list of str that representing JSON schema type"
190+ T : Annotated [type , Doc ("The type to guess the JSON schema type for" )],
191+ ) -> Annotated [
192+ Union [str , list [str ]], Doc (
193+ "str | list of str that representing JSON schema type" )
191194]:
192195 """Guesses the JSON schema type for the given python type."""
193196
194197 # special case
195- if T is typing . Any :
198+ if T is Any :
196199 return {}
197200
198- origin = typing . get_origin (T )
201+ origin = get_origin (T )
199202
200- if origin is typing . Annotated :
201- return guess_type (typing . get_args (T )[0 ])
203+ if origin is Annotated :
204+ return guess_type (get_args (T )[0 ])
202205
203206 # hacking around typing modules, `typing.Union` and `types.UnitonType`
204- if origin in [typing . Union , UnionType ]:
205- union_types = [t for t in typing . get_args (T ) if t is not type (None )]
207+ if origin in [Union , UnionType ]:
208+ union_types = [t for t in get_args (T ) if t is not type (None )]
206209 _types = [
207210 guess_type (union_type )
208211 for union_type in union_types
@@ -217,8 +220,8 @@ def guess_type(
217220 return _types [0 ]
218221 return _types
219222
220- if origin is typing . Literal :
221- type_args = typing . Union [tuple (type (arg ) for arg in typing . get_args (T ))]
223+ if origin is Literal :
224+ type_args = Union [tuple (type (arg ) for arg in get_args (T ))]
222225 return guess_type (type_args )
223226 elif origin is list or origin is tuple :
224227 return "array"
0 commit comments