Skip to content

Commit 95b80fa

Browse files
shaheerzamanDouweM
andauthored
Support models that return output tool args as {"response': "<JSON string>"} (#2836)
Co-authored-by: Douwe Maan <[email protected]>
1 parent e619d5e commit 95b80fa

File tree

2 files changed

+50
-7
lines changed

2 files changed

+50
-7
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dataclasses import dataclass, field
88
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
99

10-
from pydantic import TypeAdapter, ValidationError
10+
from pydantic import Json, TypeAdapter, ValidationError
1111
from pydantic_core import SchemaValidator, to_json
1212
from typing_extensions import Self, TypedDict, TypeVar, assert_never
1313

@@ -624,21 +624,33 @@ def __init__(
624624
json_schema = self._function_schema.json_schema
625625
json_schema['description'] = self._function_schema.description
626626
else:
627-
type_adapter: TypeAdapter[Any]
627+
json_schema_type_adapter: TypeAdapter[Any]
628+
validation_type_adapter: TypeAdapter[Any]
628629
if _utils.is_model_like(output):
629-
type_adapter = TypeAdapter(output)
630+
json_schema_type_adapter = validation_type_adapter = TypeAdapter(output)
630631
else:
631632
self.outer_typed_dict_key = 'response'
633+
output_type: type[OutputDataT] = cast(type[OutputDataT], output)
634+
632635
response_data_typed_dict = TypedDict( # noqa: UP013
633636
'response_data_typed_dict',
634-
{'response': cast(type[OutputDataT], output)}, # pyright: ignore[reportInvalidTypeForm]
637+
{'response': output_type}, # pyright: ignore[reportInvalidTypeForm]
638+
)
639+
json_schema_type_adapter = TypeAdapter(response_data_typed_dict)
640+
641+
# More lenient validator: allow either the native type or a JSON string containing it
642+
# i.e. `response: OutputDataT | Json[OutputDataT]`, as some models don't follow the schema correctly,
643+
# e.g. `BedrockConverseModel('us.meta.llama3-2-11b-instruct-v1:0')`
644+
response_validation_typed_dict = TypedDict( # noqa: UP013
645+
'response_validation_typed_dict',
646+
{'response': output_type | Json[output_type]}, # pyright: ignore[reportInvalidTypeForm]
635647
)
636-
type_adapter = TypeAdapter(response_data_typed_dict)
648+
validation_type_adapter = TypeAdapter(response_validation_typed_dict)
637649

638650
# Really a PluggableSchemaValidator, but it's API-compatible
639-
self.validator = cast(SchemaValidator, type_adapter.validator)
651+
self.validator = cast(SchemaValidator, validation_type_adapter.validator)
640652
json_schema = _utils.check_object_json_schema(
641-
type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
653+
json_schema_type_adapter.json_schema(schema_generator=GenerateToolJsonSchema)
642654
)
643655

644656
if self.outer_typed_dict_key:

tests/test_agent.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,37 @@ def return_tuple(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
7575
assert result.output == ('foo', 'bar')
7676

7777

78+
class Person(BaseModel):
79+
name: str
80+
81+
82+
def test_result_list_of_models_with_stringified_response():
83+
def return_list(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
84+
assert info.output_tools is not None
85+
# Simulate providers that return the nested payload as a JSON string under "response"
86+
args_json = json.dumps(
87+
{
88+
'response': json.dumps(
89+
[
90+
{'name': 'John Doe'},
91+
{'name': 'Jane Smith'},
92+
]
93+
)
94+
}
95+
)
96+
return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)])
97+
98+
agent = Agent(FunctionModel(return_list), output_type=list[Person])
99+
100+
result = agent.run_sync('Hello')
101+
assert result.output == snapshot(
102+
[
103+
Person(name='John Doe'),
104+
Person(name='Jane Smith'),
105+
]
106+
)
107+
108+
78109
class Foo(BaseModel):
79110
a: int
80111
b: str

0 commit comments

Comments
 (0)