1
1
from __future__ import annotations
2
2
3
3
import json
4
+ import logging
4
5
from typing import TYPE_CHECKING , Any , Iterable , cast
5
6
from typing_extensions import TypeVar , TypeGuard , assert_never
6
7
19
20
ParsedChatCompletion ,
20
21
ChatCompletionMessage ,
21
22
ParsedFunctionToolCall ,
22
- ChatCompletionToolParam ,
23
23
ParsedChatCompletionMessage ,
24
+ ChatCompletionFunctionToolParam ,
24
25
completion_create_params ,
25
26
)
26
27
from ..._exceptions import LengthFinishReasonError , ContentFilterFinishReasonError
27
28
from ...types .shared_params import FunctionDefinition
28
29
from ...types .chat .completion_create_params import ResponseFormat as ResponseFormatParam
29
- from ...types .chat .chat_completion_message_tool_call import Function
30
+ from ...types .chat .chat_completion_tool_param import ChatCompletionToolParam
31
+ from ...types .chat .chat_completion_message_function_tool_call import Function
30
32
31
33
ResponseFormatT = TypeVar (
32
34
"ResponseFormatT" ,
35
37
)
36
38
_default_response_format : None = None
37
39
40
+ log : logging .Logger = logging .getLogger ("openai.lib.parsing" )
41
+
42
+
43
+ def is_strict_chat_completion_tool_param (
44
+ tool : ChatCompletionToolParam ,
45
+ ) -> TypeGuard [ChatCompletionFunctionToolParam ]:
46
+ """Check if the given tool is a strict ChatCompletionFunctionToolParam."""
47
+ if not tool ["type" ] == "function" :
48
+ return False
49
+ if tool ["function" ].get ("strict" ) is not True :
50
+ return False
51
+
52
+ return True
53
+
54
+
55
+ def select_strict_chat_completion_tools (
56
+ tools : Iterable [ChatCompletionToolParam ] | NotGiven = NOT_GIVEN ,
57
+ ) -> Iterable [ChatCompletionFunctionToolParam ] | NotGiven :
58
+ """Select only the strict ChatCompletionFunctionToolParams from the given tools."""
59
+ if not is_given (tools ):
60
+ return NOT_GIVEN
61
+
62
+ return [t for t in tools if is_strict_chat_completion_tool_param (t )]
63
+
38
64
39
65
def validate_input_tools (
40
66
tools : Iterable [ChatCompletionToolParam ] | NotGiven = NOT_GIVEN ,
41
- ) -> None :
67
+ ) -> Iterable [ ChatCompletionFunctionToolParam ] | NotGiven :
42
68
if not is_given (tools ):
43
- return
69
+ return NOT_GIVEN
44
70
45
71
for tool in tools :
46
72
if tool ["type" ] != "function" :
@@ -54,6 +80,8 @@ def validate_input_tools(
54
80
f"`{ tool ['function' ]['name' ]} ` is not strict. Only `strict` function tools can be auto-parsed"
55
81
)
56
82
83
+ return cast (Iterable [ChatCompletionFunctionToolParam ], tools )
84
+
57
85
58
86
def parse_chat_completion (
59
87
* ,
@@ -95,6 +123,14 @@ def parse_chat_completion(
95
123
type_ = ParsedFunctionToolCall ,
96
124
)
97
125
)
126
+ elif tool_call .type == "custom" :
127
+ # warn user that custom tool calls are not callable here
128
+ log .warning (
129
+ "Custom tool calls are not callable. Ignoring tool call: %s - %s" ,
130
+ tool_call .id ,
131
+ tool_call .custom .name ,
132
+ stacklevel = 2 ,
133
+ )
98
134
elif TYPE_CHECKING : # type: ignore[unreachable]
99
135
assert_never (tool_call )
100
136
else :
@@ -129,13 +165,15 @@ def parse_chat_completion(
129
165
)
130
166
131
167
132
- def get_input_tool_by_name (* , input_tools : list [ChatCompletionToolParam ], name : str ) -> ChatCompletionToolParam | None :
133
- return next ((t for t in input_tools if t .get ("function" , {}).get ("name" ) == name ), None )
168
+ def get_input_tool_by_name (
169
+ * , input_tools : list [ChatCompletionToolParam ], name : str
170
+ ) -> ChatCompletionFunctionToolParam | None :
171
+ return next ((t for t in input_tools if t ["type" ] == "function" and t .get ("function" , {}).get ("name" ) == name ), None )
134
172
135
173
136
174
def parse_function_tool_arguments (
137
175
* , input_tools : list [ChatCompletionToolParam ], function : Function | ParsedFunction
138
- ) -> object :
176
+ ) -> object | None :
139
177
input_tool = get_input_tool_by_name (input_tools = input_tools , name = function .name )
140
178
if not input_tool :
141
179
return None
@@ -149,7 +187,7 @@ def parse_function_tool_arguments(
149
187
if not input_fn .get ("strict" ):
150
188
return None
151
189
152
- return json .loads (function .arguments )
190
+ return json .loads (function .arguments ) # type: ignore[no-any-return]
153
191
154
192
155
193
def maybe_parse_content (
@@ -209,6 +247,9 @@ def is_response_format_param(response_format: object) -> TypeGuard[ResponseForma
209
247
210
248
211
249
def is_parseable_tool (input_tool : ChatCompletionToolParam ) -> bool :
250
+ if input_tool ["type" ] != "function" :
251
+ return False
252
+
212
253
input_fn = cast (object , input_tool .get ("function" ))
213
254
if isinstance (input_fn , PydanticFunctionTool ):
214
255
return True
0 commit comments