Skip to content
Draft
52 changes: 50 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pydantic_ai._instrumentation import InstrumentationNames

from . import _function_schema, _utils, messages as _messages
from ._json_schema import JsonSchema
from ._run_context import AgentDepsT, RunContext
from .exceptions import ModelRetry, ToolRetryError, UserError
from .output import (
Expand Down Expand Up @@ -226,6 +227,33 @@ def mode(self) -> OutputMode:
def allows_text(self) -> bool:
return self.text_processor is not None

def json_schema(self) -> JsonSchema:
additional_outputs: Sequence[type] = []
if self.allows_text:
additional_outputs.append(str)
if self.allows_deferred_tools:
additional_outputs.append(DeferredToolRequests)
if self.allows_image:
additional_outputs.append(_messages.BinaryImage)

processors = None
# if processor exists, it should override toolsets
if hasattr(self, 'processor'):
processors = {self.processor.object_def.name: self.processor}
elif self.toolset:
processors = self.toolset.processors

# special case where we don't want to union
if len(additional_outputs) == 1 and not processors:
return TypeAdapter(additional_outputs[0]).json_schema()

json_schema = UnionOutputProcessor(
outputs=additional_outputs,
processors=processors,
).object_def.json_schema

return json_schema

@classmethod
def build( # noqa: C901
cls,
Expand Down Expand Up @@ -395,7 +423,7 @@ def __init__(
super().__init__(
toolset=toolset,
object_def=processor.object_def,
text_processor=processor,
text_processor=processor, # always triggers allows_text to be true, maybe a bug?
allows_deferred_tools=allows_deferred_tools,
allows_image=allows_image,
)
Expand Down Expand Up @@ -693,16 +721,36 @@ def __init__(
name: str | None = None,
description: str | None = None,
strict: bool | None = None,
processors: dict[str, ObjectOutputProcessor[OutputDataT]] | None = None,
):
self._union_processor = ObjectOutputProcessor(output=UnionOutputModel)

json_schemas: list[ObjectJsonSchema] = []
self._processors = {}

if processors:
for name, processor in processors.items():
object_key = name
i = 1
original_key = object_key
while object_key in self._processors:
i += 1
object_key = f'{original_key}_{i}'
self._processors[object_key] = processor

object_def = processor.object_def
json_schema = object_def.json_schema
if object_def.name: # pragma: no branch
json_schema['title'] = name
if object_def.description:
json_schema['description'] = object_def.description
json_schemas.append(json_schema)

for output in outputs:
processor = ObjectOutputProcessor(output=output, strict=strict)
object_def = processor.object_def

object_key = object_def.name or output.__name__

i = 1
original_key = object_key
while object_key in self._processors:
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
UserPromptNode,
capture_run_messages,
)
from .._json_schema import JsonSchema
from .._output import OutputToolset
from .._tool_manager import ToolManager
from ..builtin_tools import AbstractBuiltinTool
Expand Down Expand Up @@ -947,6 +948,11 @@ def decorator(
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
return func

def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
"""The output JSON schema."""
output_schema = self._prepare_output_schema(output_type)
return output_schema.json_schema()

@overload
def output_validator(
self, func: Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT], /
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
result,
usage as _usage,
)
from .._json_schema import JsonSchema
from .._tool_manager import ToolManager
from ..builtin_tools import AbstractBuiltinTool
from ..output import OutputDataT, OutputSpec
Expand Down Expand Up @@ -122,6 +123,11 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
"""
raise NotImplementedError

@abstractmethod
def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
"""The output JSON schema."""
raise NotImplementedError

@overload
async def run(
self,
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
models,
usage as _usage,
)
from .._json_schema import JsonSchema
from ..builtin_tools import AbstractBuiltinTool
from ..output import OutputDataT, OutputSpec
from ..run import AgentRun
Expand Down Expand Up @@ -67,6 +68,9 @@ async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
async def __aexit__(self, *args: Any) -> bool | None:
return await self.wrapped.__aexit__(*args)

def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
return self.wrapped.output_json_schema(output_type=output_type)

@overload
def iter(
self,
Expand Down
42 changes: 42 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5242,6 +5242,48 @@ def foo() -> str:
assert wrapper_agent.name == 'wrapped'
assert wrapper_agent.output_type == agent.output_type
assert wrapper_agent.event_stream_handler == agent.event_stream_handler
assert wrapper_agent.output_json_schema() == snapshot(
{
'type': 'object',
'properties': {
'result': {
'anyOf': [
{
'type': 'object',
'properties': {
'kind': {'type': 'string', 'const': 'Foo'},
'data': {
'properties': {'a': {'type': 'integer'}, 'b': {'type': 'string'}},
'required': ['a', 'b'],
'type': 'object',
},
},
'required': ['kind', 'data'],
'additionalProperties': False,
'title': 'Foo',
},
{
'type': 'object',
'properties': {
'kind': {'type': 'string', 'const': 'str'},
'data': {
'properties': {'response': {'type': 'string'}},
'required': ['response'],
'type': 'object',
},
},
'required': ['kind', 'data'],
'additionalProperties': False,
'title': 'str',
},
]
}
},
'required': ['result'],
'additionalProperties': False,
}
)
assert wrapper_agent.output_json_schema(output_type=str) == snapshot({'type': 'string'})

bar_toolset = FunctionToolset()

Expand Down
Loading
Loading