Skip to content

Commit 7997827

Browse files
authored
Remove ArgsDict and ArgsJson (#769)
1 parent bb5f740 commit 7997827

29 files changed

+269
-382
lines changed

docs/agents.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ with capture_run_messages() as messages: # (2)!
463463
parts=[
464464
ToolCallPart(
465465
tool_name='calc_volume',
466-
args=ArgsDict(args_dict={'size': 6}),
466+
args={'size': 6},
467467
tool_call_id=None,
468468
part_kind='tool-call',
469469
)
@@ -488,7 +488,7 @@ with capture_run_messages() as messages: # (2)!
488488
parts=[
489489
ToolCallPart(
490490
tool_name='calc_volume',
491-
args=ArgsDict(args_dict={'size': 6}),
491+
args={'size': 6},
492492
tool_call_id=None,
493493
part_kind='tool-call',
494494
)

docs/testing-evals.md

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ from dirty_equals import IsNow
9898
from pydantic_ai import models, capture_run_messages
9999
from pydantic_ai.models.test import TestModel
100100
from pydantic_ai.messages import (
101-
ArgsDict,
102101
ModelResponse,
103102
SystemPromptPart,
104103
TextPart,
@@ -142,12 +141,10 @@ async def test_forecast():
142141
parts=[
143142
ToolCallPart(
144143
tool_name='weather_forecast',
145-
args=ArgsDict(
146-
args_dict={
147-
'location': 'a',
148-
'forecast_date': '2024-01-01', # (8)!
149-
}
150-
),
144+
args={
145+
'location': 'a',
146+
'forecast_date': '2024-01-01', # (8)!
147+
},
151148
tool_call_id=None,
152149
)
153150
],
@@ -223,9 +220,7 @@ def call_weather_forecast( # (1)!
223220
m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content)
224221
assert m is not None
225222
args = {'location': 'London', 'forecast_date': m.group()} # (2)!
226-
return ModelResponse(
227-
parts=[ToolCallPart.from_raw_args('weather_forecast', args)]
228-
)
223+
return ModelResponse(parts=[ToolCallPart('weather_forecast', args)])
229224
else:
230225
# second call, return the forecast
231226
msg = messages[-1].parts[0]

docs/tools.md

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,7 @@ print(dice_result.all_messages())
8686
ModelResponse(
8787
parts=[
8888
ToolCallPart(
89-
tool_name='roll_die',
90-
args=ArgsDict(args_dict={}),
91-
tool_call_id=None,
92-
part_kind='tool-call',
89+
tool_name='roll_die', args={}, tool_call_id=None, part_kind='tool-call'
9390
)
9491
],
9592
model_name='function:model_logic',
@@ -112,7 +109,7 @@ print(dice_result.all_messages())
112109
parts=[
113110
ToolCallPart(
114111
tool_name='get_player_name',
115-
args=ArgsDict(args_dict={}),
112+
args={},
116113
tool_call_id=None,
117114
part_kind='tool-call',
118115
)

pydantic_ai_slim/pydantic_ai/_parts_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def handle_tool_call_part(
221221
ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
222222
has been added to the manager, or replaced an existing part.
223223
"""
224-
new_part = ToolCallPart.from_raw_args(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
224+
new_part = ToolCallPart(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
225225
if vendor_part_id is None:
226226
# vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
227227
new_part_index = len(self._parts)

pydantic_ai_slim/pydantic_ai/_result.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -201,14 +201,10 @@ def validate(
201201
"""
202202
try:
203203
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
204-
if isinstance(tool_call.args, _messages.ArgsJson):
205-
result = self.type_adapter.validate_json(
206-
tool_call.args.args_json or '', experimental_allow_partial=pyd_allow_partial
207-
)
204+
if isinstance(tool_call.args, str):
205+
result = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial)
208206
else:
209-
result = self.type_adapter.validate_python(
210-
tool_call.args.args_dict, experimental_allow_partial=pyd_allow_partial
211-
)
207+
result = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial)
212208
except ValidationError as e:
213209
if wrap_validation_errors:
214210
m = _messages.RetryPromptPart(

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 21 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import pydantic
88
import pydantic_core
9-
from typing_extensions import Self, assert_never
109

1110
from ._utils import now_utc as _now_utc
1211
from .exceptions import UnexpectedModelBehavior
@@ -168,33 +167,17 @@ def has_content(self) -> bool:
168167
return bool(self.content)
169168

170169

171-
@dataclass
172-
class ArgsJson:
173-
"""Tool arguments as a JSON string."""
174-
175-
args_json: str
176-
"""A JSON string of arguments."""
177-
178-
179-
@dataclass
180-
class ArgsDict:
181-
"""Tool arguments as a Python dictionary."""
182-
183-
args_dict: dict[str, Any]
184-
"""A python dictionary of arguments."""
185-
186-
187170
@dataclass
188171
class ToolCallPart:
189172
"""A tool call from a model."""
190173

191174
tool_name: str
192175
"""The name of the tool to call."""
193176

194-
args: ArgsJson | ArgsDict
177+
args: str | dict[str, Any]
195178
"""The arguments to pass to the tool.
196179
197-
Either as JSON or a Python dictionary depending on how data was returned.
180+
This is stored either as a JSON string or a Python dictionary depending on how data was received.
198181
"""
199182

200183
tool_call_id: str | None = None
@@ -203,24 +186,14 @@ class ToolCallPart:
203186
part_kind: Literal['tool-call'] = 'tool-call'
204187
"""Part type identifier, this is available on all parts as a discriminator."""
205188

206-
@classmethod
207-
def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
208-
"""Create a `ToolCallPart` from raw arguments, converting them to `ArgsJson` or `ArgsDict`."""
209-
if isinstance(args, str):
210-
return cls(tool_name, ArgsJson(args), tool_call_id)
211-
elif isinstance(args, dict):
212-
return cls(tool_name, ArgsDict(args), tool_call_id)
213-
else:
214-
assert_never(args)
215-
216189
def args_as_dict(self) -> dict[str, Any]:
217190
"""Return the arguments as a Python dictionary.
218191
219192
This is just for convenience with models that require dicts as input.
220193
"""
221-
if isinstance(self.args, ArgsDict):
222-
return self.args.args_dict
223-
args = pydantic_core.from_json(self.args.args_json)
194+
if isinstance(self.args, dict):
195+
return self.args
196+
args = pydantic_core.from_json(self.args)
224197
assert isinstance(args, dict), 'args should be a dict'
225198
return cast(dict[str, Any], args)
226199

@@ -229,16 +202,18 @@ def args_as_json_str(self) -> str:
229202
230203
This is just for convenience with models that require JSON strings as input.
231204
"""
232-
if isinstance(self.args, ArgsJson):
233-
return self.args.args_json
234-
return pydantic_core.to_json(self.args.args_dict).decode()
205+
if isinstance(self.args, str):
206+
return self.args
207+
return pydantic_core.to_json(self.args).decode()
235208

236209
def has_content(self) -> bool:
237210
"""Return `True` if the arguments contain any data."""
238-
if isinstance(self.args, ArgsDict):
239-
return any(self.args.args_dict.values())
211+
if isinstance(self.args, dict):
212+
# TODO: This should probably return True if you have the value False, or 0, etc.
213+
# It makes sense to me to ignore empty strings, but not sure about empty lists or dicts
214+
return any(self.args.values())
240215
else:
241-
return bool(self.args.args_json)
216+
return bool(self.args)
242217

243218

244219
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
@@ -331,7 +306,7 @@ def as_part(self) -> ToolCallPart | None:
331306
if self.tool_name_delta is None or self.args_delta is None:
332307
return None
333308

334-
return ToolCallPart.from_raw_args(
309+
return ToolCallPart(
335310
self.tool_name_delta,
336311
self.args_delta,
337312
self.tool_call_id,
@@ -396,7 +371,7 @@ def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPa
396371

397372
# If we now have enough data to create a full ToolCallPart, do so
398373
if delta.tool_name_delta is not None and delta.args_delta is not None:
399-
return ToolCallPart.from_raw_args(
374+
return ToolCallPart(
400375
delta.tool_name_delta,
401376
delta.args_delta,
402377
delta.tool_call_id,
@@ -412,15 +387,15 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
412387
part = replace(part, tool_name=tool_name)
413388

414389
if isinstance(self.args_delta, str):
415-
if not isinstance(part.args, ArgsJson):
390+
if not isinstance(part.args, str):
416391
raise UnexpectedModelBehavior(f'Cannot apply JSON deltas to non-JSON tool arguments ({part=}, {self=})')
417-
updated_json = part.args.args_json + self.args_delta
418-
part = replace(part, args=ArgsJson(updated_json))
392+
updated_json = part.args + self.args_delta
393+
part = replace(part, args=updated_json)
419394
elif isinstance(self.args_delta, dict):
420-
if not isinstance(part.args, ArgsDict):
395+
if not isinstance(part.args, dict):
421396
raise UnexpectedModelBehavior(f'Cannot apply dict deltas to non-dict tool arguments ({part=}, {self=})')
422-
updated_dict = {**(part.args.args_dict or {}), **self.args_delta}
423-
part = replace(part, args=ArgsDict(updated_dict))
397+
updated_dict = {**(part.args or {}), **self.args_delta}
398+
part = replace(part, args=updated_dict)
424399

425400
if self.tool_call_id:
426401
# Replace the tool_call_id entirely if given

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from .. import UnexpectedModelBehavior, _utils, usage
1414
from .._utils import guard_tool_call_id as _guard_tool_call_id
1515
from ..messages import (
16-
ArgsDict,
1716
ModelMessage,
1817
ModelRequest,
1918
ModelResponse,
@@ -242,7 +241,7 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
242241
else:
243242
assert isinstance(item, ToolUseBlock), 'unexpected item type'
244243
items.append(
245-
ToolCallPart.from_raw_args(
244+
ToolCallPart(
246245
tool_name=item.name,
247246
args=cast(dict[str, Any], item.input),
248247
tool_call_id=item.id,
@@ -319,7 +318,6 @@ def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]]
319318

320319

321320
def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
322-
assert isinstance(t.args, ArgsDict), f'Expected ArgsDict, got {t.args}'
323321
return ToolUseBlockParam(
324322
id=_guard_tool_call_id(t=t, model_source='Anthropic'),
325323
type='tool_use',

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def _process_response(self, response: ChatResponse) -> ModelResponse:
191191
for c in response.message.tool_calls or []:
192192
if c.function and c.function.name and c.function.arguments:
193193
parts.append(
194-
ToolCallPart.from_raw_args(
194+
ToolCallPart(
195195
tool_name=c.function.name,
196196
args=c.function.arguments,
197197
tool_call_id=c.id,

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def _process_response_from_parts(
453453
items.append(TextPart(content=part['text']))
454454
elif 'function_call' in part:
455455
items.append(
456-
ToolCallPart.from_raw_args(
456+
ToolCallPart(
457457
tool_name=part['function_call']['name'],
458458
args=part['function_call']['args'],
459459
)

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,7 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
224224
items.append(TextPart(content=choice.message.content))
225225
if choice.message.tool_calls is not None:
226226
for c in choice.message.tool_calls:
227-
items.append(
228-
ToolCallPart.from_raw_args(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id)
229-
)
227+
items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
230228
return ModelResponse(items, model_name=self.model_name, timestamp=timestamp)
231229

232230
async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> GroqStreamedResponse:

0 commit comments

Comments
 (0)