Skip to content

Commit a2241c2

Browse files
authored
Make args handling more robust (#489)
1 parent a6260fc commit a2241c2

File tree

16 files changed

+104
-108
lines changed

16 files changed

+104
-108
lines changed

docs/testing-evals.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,9 @@ def call_weather_forecast( # (1)!
219219
m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content)
220220
assert m is not None
221221
args = {'location': 'London', 'forecast_date': m.group()} # (2)!
222-
return ModelResponse(parts=[ToolCallPart.from_dict('weather_forecast', args)])
222+
return ModelResponse(
223+
parts=[ToolCallPart.from_raw_args('weather_forecast', args)]
224+
)
223225
else:
224226
# second call, return the forecast
225227
msg = messages[-1].parts[0]

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from dataclasses import dataclass, field
44
from datetime import datetime
5-
from typing import Annotated, Any, Literal, Union
5+
from typing import Annotated, Any, Literal, Union, cast
66

77
import pydantic
88
import pydantic_core
9-
from typing_extensions import Self
9+
from typing_extensions import Self, assert_never
1010

1111
from ._utils import now_utc as _now_utc
1212

@@ -190,12 +190,34 @@ class ToolCallPart:
190190
"""Part type identifier, this is available on all parts as a discriminator."""
191191

192192
@classmethod
193-
def from_json(cls, tool_name: str, args_json: str, tool_call_id: str | None = None) -> Self:
194-
return cls(tool_name, ArgsJson(args_json), tool_call_id)
193+
def from_raw_args(cls, tool_name: str, args: str | dict[str, Any], tool_call_id: str | None = None) -> Self:
194+
"""Create a `ToolCallPart` from raw arguments."""
195+
if isinstance(args, str):
196+
return cls(tool_name, ArgsJson(args), tool_call_id)
197+
elif isinstance(args, dict):
198+
return cls(tool_name, ArgsDict(args), tool_call_id)
199+
else:
200+
assert_never(args)
195201

196-
@classmethod
197-
def from_dict(cls, tool_name: str, args_dict: dict[str, Any], tool_call_id: str | None = None) -> Self:
198-
return cls(tool_name, ArgsDict(args_dict), tool_call_id)
202+
def args_as_dict(self) -> dict[str, Any]:
203+
"""Return the arguments as a Python dictionary.
204+
205+
This is just for convenience with models that require dicts as input.
206+
"""
207+
if isinstance(self.args, ArgsDict):
208+
return self.args.args_dict
209+
args = pydantic_core.from_json(self.args.args_json)
210+
assert isinstance(args, dict), 'args should be a dict'
211+
return cast(dict[str, Any], args)
212+
213+
def args_as_json_str(self) -> str:
214+
"""Return the arguments as a JSON string.
215+
216+
This is just for convenience with models that require JSON strings as input.
217+
"""
218+
if isinstance(self.args, ArgsJson):
219+
return self.args.args_json
220+
return pydantic_core.to_json(self.args.args_dict).decode()
199221

200222
def has_content(self) -> bool:
201223
if isinstance(self.args, ArgsDict):

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def _process_response(response: AnthropicMessage) -> ModelResponse:
220220
else:
221221
assert isinstance(item, ToolUseBlock), 'unexpected item type'
222222
items.append(
223-
ToolCallPart.from_dict(
223+
ToolCallPart.from_raw_args(
224224
item.name,
225225
cast(dict[str, Any], item.input),
226226
item.id,
@@ -311,7 +311,7 @@ def _map_tool_call(t: ToolCallPart) -> ToolUseBlockParam:
311311
id=_guard_tool_call_id(t=t, model_source='Anthropic'),
312312
type='tool_use',
313313
name=t.tool_name,
314-
input=t.args.args_dict,
314+
input=t.args_as_dict(),
315315
)
316316

317317

pydantic_ai_slim/pydantic_ai/models/function.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,10 @@
99
from itertools import chain
1010
from typing import Callable, Union, cast
1111

12-
import pydantic_core
1312
from typing_extensions import TypeAlias, assert_never, overload
1413

1514
from .. import _utils, result
1615
from ..messages import (
17-
ArgsJson,
1816
ModelMessage,
1917
ModelRequest,
2018
ModelResponse,
@@ -232,7 +230,7 @@ def get(self, *, final: bool = False) -> ModelResponse:
232230
calls: list[ModelResponsePart] = []
233231
for c in self._delta_tool_calls.values():
234232
if c.name is not None and c.json_args is not None:
235-
calls.append(ToolCallPart.from_json(c.name, c.json_args))
233+
calls.append(ToolCallPart.from_raw_args(c.name, c.json_args))
236234

237235
return ModelResponse(calls, timestamp=self._timestamp)
238236

@@ -268,11 +266,7 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage:
268266
response_tokens += _estimate_string_usage(part.content)
269267
elif isinstance(part, ToolCallPart):
270268
call = part
271-
if isinstance(call.args, ArgsJson):
272-
args_str = call.args.args_json
273-
else:
274-
args_str = pydantic_core.to_json(call.args.args_dict).decode()
275-
response_tokens += 1 + _estimate_string_usage(args_str)
269+
response_tokens += 1 + _estimate_string_usage(call.args_as_json_str())
276270
else:
277271
assert_never(part)
278272
else:

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from .. import UnexpectedModelBehavior, _utils, exceptions, result
1818
from ..messages import (
19-
ArgsDict,
2019
ModelMessage,
2120
ModelRequest,
2221
ModelResponse,
@@ -460,8 +459,7 @@ class _GeminiFunctionCallPart(TypedDict):
460459

461460

462461
def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:
463-
assert isinstance(tool.args, ArgsDict), f'Expected ArgsObject, got {tool.args}'
464-
return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args.args_dict))
462+
return _GeminiFunctionCallPart(function_call=_GeminiFunctionCall(name=tool.tool_name, args=tool.args_as_dict()))
465463

466464

467465
def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: datetime | None = None) -> ModelResponse:
@@ -470,7 +468,7 @@ def _process_response_from_parts(parts: Sequence[_GeminiPartUnion], timestamp: d
470468
if 'text' in part:
471469
items.append(TextPart(part['text']))
472470
elif 'function_call' in part:
473-
items.append(ToolCallPart.from_dict(part['function_call']['name'], part['function_call']['args']))
471+
items.append(ToolCallPart.from_raw_args(part['function_call']['name'], part['function_call']['args']))
474472
elif 'function_response' in part:
475473
raise exceptions.UnexpectedModelBehavior(
476474
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from .. import UnexpectedModelBehavior, _utils, result
1414
from .._utils import guard_tool_call_id as _guard_tool_call_id
1515
from ..messages import (
16-
ArgsJson,
1716
ModelMessage,
1817
ModelRequest,
1918
ModelResponse,
@@ -221,7 +220,7 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse:
221220
items.append(TextPart(choice.message.content))
222221
if choice.message.tool_calls is not None:
223222
for c in choice.message.tool_calls:
224-
items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
223+
items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
225224
return ModelResponse(items, timestamp=timestamp)
226225

227226
@staticmethod
@@ -380,7 +379,7 @@ def get(self, *, final: bool = False) -> ModelResponse:
380379
for c in self._delta_tool_calls.values():
381380
if f := c.function:
382381
if f.name is not None and f.arguments is not None:
383-
items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
382+
items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
384383

385384
return ModelResponse(items, timestamp=self._timestamp)
386385

@@ -392,11 +391,10 @@ def timestamp(self) -> datetime:
392391

393392

394393
def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
395-
assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
396394
return chat.ChatCompletionMessageToolCallParam(
397395
id=_guard_tool_call_id(t=t, model_source='Groq'),
398396
type='function',
399-
function={'name': t.tool_name, 'arguments': t.args.args_json},
397+
function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
400398
)
401399

402400

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -560,13 +560,10 @@ def get(self, *, final: bool = False) -> ModelResponse:
560560
# when `{"response":` then `repair_json` sets `{"response": ""}` (type not found default str)
561561
# when `{"response": {` then `repair_json` sets `{"response": {}}` (type found)
562562
# This ensures it's corrected to `{"response": {}}` and other required parameters and type.
563-
if not self._validate_required_json_shema(output_json, result_tool.parameters_json_schema):
563+
if not self._validate_required_json_schema(output_json, result_tool.parameters_json_schema):
564564
continue
565565

566-
tool = ToolCallPart.from_dict(
567-
tool_name=result_tool.name,
568-
args_dict=output_json,
569-
)
566+
tool = ToolCallPart.from_raw_args(result_tool.name, output_json)
570567
calls.append(tool)
571568

572569
return ModelResponse(calls, timestamp=self._timestamp)
@@ -578,7 +575,7 @@ def timestamp(self) -> datetime:
578575
return self._timestamp
579576

580577
@staticmethod
581-
def _validate_required_json_shema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
578+
def _validate_required_json_schema(json_dict: dict[str, Any], json_schema: dict[str, Any]) -> bool:
582579
"""Validate that all required parameters in the JSON schema are present in the JSON dictionary."""
583580
required_params = json_schema.get('required', [])
584581
properties = json_schema.get('properties', {})
@@ -602,7 +599,7 @@ def _validate_required_json_shema(json_dict: dict[str, Any], json_schema: dict[s
602599

603600
if isinstance(json_dict[param], dict) and 'properties' in param_schema:
604601
nested_schema = param_schema
605-
if not MistralStreamStructuredResponse._validate_required_json_shema(json_dict[param], nested_schema):
602+
if not MistralStreamStructuredResponse._validate_required_json_schema(json_dict[param], nested_schema):
606603
return False
607604

608605
return True
@@ -633,16 +630,7 @@ def _map_mistral_to_pydantic_tool_call(tool_call: MistralToolCall) -> ToolCallPa
633630
tool_call_id = tool_call.id or None
634631
func_call = tool_call.function
635632

636-
if isinstance(func_call.arguments, str):
637-
return ToolCallPart.from_json(
638-
tool_name=func_call.name,
639-
args_json=func_call.arguments,
640-
tool_call_id=tool_call_id,
641-
)
642-
else:
643-
return ToolCallPart.from_dict(
644-
tool_name=func_call.name, args_dict=func_call.arguments, tool_call_id=tool_call_id
645-
)
633+
return ToolCallPart.from_raw_args(func_call.name, func_call.arguments, tool_call_id)
646634

647635

648636
def _map_usage(response: MistralChatCompletionResponse | MistralCompletionChunk) -> Usage:

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from .. import UnexpectedModelBehavior, _utils, result
1414
from .._utils import guard_tool_call_id as _guard_tool_call_id
1515
from ..messages import (
16-
ArgsJson,
1716
ModelMessage,
1817
ModelRequest,
1918
ModelResponse,
@@ -211,7 +210,7 @@ def _process_response(response: chat.ChatCompletion) -> ModelResponse:
211210
items.append(TextPart(choice.message.content))
212211
if choice.message.tool_calls is not None:
213212
for c in choice.message.tool_calls:
214-
items.append(ToolCallPart.from_json(c.function.name, c.function.arguments, c.id))
213+
items.append(ToolCallPart.from_raw_args(c.function.name, c.function.arguments, c.id))
215214
return ModelResponse(items, timestamp=timestamp)
216215

217216
@staticmethod
@@ -372,7 +371,7 @@ def get(self, *, final: bool = False) -> ModelResponse:
372371
for c in self._delta_tool_calls.values():
373372
if f := c.function:
374373
if f.name is not None and f.arguments is not None:
375-
items.append(ToolCallPart.from_json(f.name, f.arguments, c.id))
374+
items.append(ToolCallPart.from_raw_args(f.name, f.arguments, c.id))
376375

377376
return ModelResponse(items, timestamp=self._timestamp)
378377

@@ -384,11 +383,10 @@ def timestamp(self) -> datetime:
384383

385384

386385
def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
387-
assert isinstance(t.args, ArgsJson), f'Expected ArgsJson, got {t.args}'
388386
return chat.ChatCompletionMessageToolCallParam(
389387
id=_guard_tool_call_id(t=t, model_source='OpenAI'),
390388
type='function',
391-
function={'name': t.tool_name, 'arguments': t.args.args_json},
389+
function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
392390
)
393391

394392

pydantic_ai_slim/pydantic_ai/models/test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings |
167167
# if there are tools, the first thing we want to do is call all of them
168168
if self.tool_calls and not any(isinstance(m, ModelResponse) for m in messages):
169169
return ModelResponse(
170-
parts=[ToolCallPart.from_dict(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
170+
parts=[ToolCallPart.from_raw_args(name, self.gen_tool_args(args)) for name, args in self.tool_calls]
171171
)
172172

173173
if messages:
@@ -179,7 +179,7 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings |
179179
if new_retry_names:
180180
return ModelResponse(
181181
parts=[
182-
ToolCallPart.from_dict(name, self.gen_tool_args(args))
182+
ToolCallPart.from_raw_args(name, self.gen_tool_args(args))
183183
for name, args in self.tool_calls
184184
if name in new_retry_names
185185
]
@@ -205,10 +205,10 @@ def _request(self, messages: list[ModelMessage], model_settings: ModelSettings |
205205
custom_result_args = self.result.right
206206
result_tool = self.result_tools[self.seed % len(self.result_tools)]
207207
if custom_result_args is not None:
208-
return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, custom_result_args)])
208+
return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, custom_result_args)])
209209
else:
210210
response_args = self.gen_tool_args(result_tool)
211-
return ModelResponse(parts=[ToolCallPart.from_dict(result_tool.name, response_args)])
211+
return ModelResponse(parts=[ToolCallPart.from_raw_args(result_tool.name, response_args)])
212212

213213

214214
@dataclass

tests/models/test_gemini.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
413413
async def test_request_structured_response(get_gemini_client: GetGeminiClient):
414414
response = gemini_response(
415415
_content_model_response(
416-
ModelResponse(parts=[ToolCallPart.from_dict('final_result', {'response': [1, 2, 123]})])
416+
ModelResponse(parts=[ToolCallPart.from_raw_args('final_result', {'response': [1, 2, 123]})])
417417
)
418418
)
419419
gemini_client = get_gemini_client(response)
@@ -449,12 +449,12 @@ async def test_request_tool_call(get_gemini_client: GetGeminiClient):
449449
responses = [
450450
gemini_response(
451451
_content_model_response(
452-
ModelResponse(parts=[ToolCallPart.from_dict('get_location', {'loc_name': 'San Fransisco'})])
452+
ModelResponse(parts=[ToolCallPart.from_raw_args('get_location', {'loc_name': 'San Fransisco'})])
453453
)
454454
),
455455
gemini_response(
456456
_content_model_response(
457-
ModelResponse(parts=[ToolCallPart.from_dict('get_location', {'loc_name': 'London'})])
457+
ModelResponse(parts=[ToolCallPart.from_raw_args('get_location', {'loc_name': 'London'})])
458458
)
459459
),
460460
gemini_response(_content_model_response(ModelResponse.from_text('final response'))),
@@ -596,7 +596,7 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient):
596596
responses = [
597597
gemini_response(
598598
_content_model_response(
599-
ModelResponse(parts=[ToolCallPart.from_dict('final_result', {'response': [1, 2]})])
599+
ModelResponse(parts=[ToolCallPart.from_raw_args('final_result', {'response': [1, 2]})])
600600
),
601601
),
602602
]
@@ -615,10 +615,10 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient):
615615
async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
616616
first_responses = [
617617
gemini_response(
618-
_content_model_response(ModelResponse(parts=[ToolCallPart.from_dict('foo', {'x': 'a'})])),
618+
_content_model_response(ModelResponse(parts=[ToolCallPart.from_raw_args('foo', {'x': 'a'})])),
619619
),
620620
gemini_response(
621-
_content_model_response(ModelResponse(parts=[ToolCallPart.from_dict('bar', {'y': 'b'})])),
621+
_content_model_response(ModelResponse(parts=[ToolCallPart.from_raw_args('bar', {'y': 'b'})])),
622622
),
623623
]
624624
d1 = _gemini_streamed_response_ta.dump_json(first_responses, by_alias=True)
@@ -627,7 +627,7 @@ async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
627627
second_responses = [
628628
gemini_response(
629629
_content_model_response(
630-
ModelResponse(parts=[ToolCallPart.from_dict('final_result', {'response': [1, 2]})])
630+
ModelResponse(parts=[ToolCallPart.from_raw_args('final_result', {'response': [1, 2]})])
631631
),
632632
),
633633
]

0 commit comments

Comments
 (0)