6
6
7
7
import pydantic
8
8
import pydantic_core
9
- from typing_extensions import Self , assert_never
10
9
11
10
from ._utils import now_utc as _now_utc
12
11
from .exceptions import UnexpectedModelBehavior
@@ -168,33 +167,17 @@ def has_content(self) -> bool:
168
167
return bool (self .content )
169
168
170
169
171
- @dataclass
172
- class ArgsJson :
173
- """Tool arguments as a JSON string."""
174
-
175
- args_json : str
176
- """A JSON string of arguments."""
177
-
178
-
179
- @dataclass
180
- class ArgsDict :
181
- """Tool arguments as a Python dictionary."""
182
-
183
- args_dict : dict [str , Any ]
184
- """A python dictionary of arguments."""
185
-
186
-
187
170
@dataclass
188
171
class ToolCallPart :
189
172
"""A tool call from a model."""
190
173
191
174
tool_name : str
192
175
"""The name of the tool to call."""
193
176
194
- args : ArgsJson | ArgsDict
177
+ args : str | dict [ str , Any ]
195
178
"""The arguments to pass to the tool.
196
179
197
- Either as JSON or a Python dictionary depending on how data was returned .
180
+ This is stored either as a JSON string or a Python dictionary depending on how data was received .
198
181
"""
199
182
200
183
tool_call_id : str | None = None
@@ -203,24 +186,14 @@ class ToolCallPart:
203
186
part_kind : Literal ['tool-call' ] = 'tool-call'
204
187
"""Part type identifier, this is available on all parts as a discriminator."""
205
188
206
- @classmethod
207
- def from_raw_args (cls , tool_name : str , args : str | dict [str , Any ], tool_call_id : str | None = None ) -> Self :
208
- """Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
209
- if isinstance (args , str ):
210
- return cls (tool_name , ArgsJson (args ), tool_call_id )
211
- elif isinstance (args , dict ):
212
- return cls (tool_name , ArgsDict (args ), tool_call_id )
213
- else :
214
- assert_never (args )
215
-
216
189
def args_as_dict (self ) -> dict [str , Any ]:
217
190
"""Return the arguments as a Python dictionary.
218
191
219
192
This is just for convenience with models that require dicts as input.
220
193
"""
221
- if isinstance (self .args , ArgsDict ):
222
- return self .args . args_dict
223
- args = pydantic_core .from_json (self .args . args_json )
194
+ if isinstance (self .args , dict ):
195
+ return self .args
196
+ args = pydantic_core .from_json (self .args )
224
197
assert isinstance (args , dict ), 'args should be a dict'
225
198
return cast (dict [str , Any ], args )
226
199
@@ -229,16 +202,18 @@ def args_as_json_str(self) -> str:
229
202
230
203
This is just for convenience with models that require JSON strings as input.
231
204
"""
232
- if isinstance (self .args , ArgsJson ):
233
- return self .args . args_json
234
- return pydantic_core .to_json (self .args . args_dict ).decode ()
205
+ if isinstance (self .args , str ):
206
+ return self .args
207
+ return pydantic_core .to_json (self .args ).decode ()
235
208
236
209
def has_content (self ) -> bool :
237
210
"""Return `True` if the arguments contain any data."""
238
- if isinstance (self .args , ArgsDict ):
239
- return any (self .args .args_dict .values ())
211
+ if isinstance (self .args , dict ):
212
+ # TODO: This should probably return True if you have the value False, or 0, etc.
213
+ # It makes sense to me to ignore empty strings, but not sure about empty lists or dicts
214
+ return any (self .args .values ())
240
215
else :
241
- return bool (self .args . args_json )
216
+ return bool (self .args )
242
217
243
218
244
219
ModelResponsePart = Annotated [Union [TextPart , ToolCallPart ], pydantic .Discriminator ('part_kind' )]
@@ -331,7 +306,7 @@ def as_part(self) -> ToolCallPart | None:
331
306
if self .tool_name_delta is None or self .args_delta is None :
332
307
return None
333
308
334
- return ToolCallPart . from_raw_args (
309
+ return ToolCallPart (
335
310
self .tool_name_delta ,
336
311
self .args_delta ,
337
312
self .tool_call_id ,
@@ -396,7 +371,7 @@ def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPa
396
371
397
372
# If we now have enough data to create a full ToolCallPart, do so
398
373
if delta .tool_name_delta is not None and delta .args_delta is not None :
399
- return ToolCallPart . from_raw_args (
374
+ return ToolCallPart (
400
375
delta .tool_name_delta ,
401
376
delta .args_delta ,
402
377
delta .tool_call_id ,
@@ -412,15 +387,15 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
412
387
part = replace (part , tool_name = tool_name )
413
388
414
389
if isinstance (self .args_delta , str ):
415
- if not isinstance (part .args , ArgsJson ):
390
+ if not isinstance (part .args , str ):
416
391
raise UnexpectedModelBehavior (f'Cannot apply JSON deltas to non-JSON tool arguments ({ part = } , { self = } )' )
417
- updated_json = part .args . args_json + self .args_delta
418
- part = replace (part , args = ArgsJson ( updated_json ) )
392
+ updated_json = part .args + self .args_delta
393
+ part = replace (part , args = updated_json )
419
394
elif isinstance (self .args_delta , dict ):
420
- if not isinstance (part .args , ArgsDict ):
395
+ if not isinstance (part .args , dict ):
421
396
raise UnexpectedModelBehavior (f'Cannot apply dict deltas to non-dict tool arguments ({ part = } , { self = } )' )
422
- updated_dict = {** (part .args . args_dict or {}), ** self .args_delta }
423
- part = replace (part , args = ArgsDict ( updated_dict ) )
397
+ updated_dict = {** (part .args or {}), ** self .args_delta }
398
+ part = replace (part , args = updated_dict )
424
399
425
400
if self .tool_call_id :
426
401
# Replace the tool_call_id entirely if given
0 commit comments