def json_schema(self) -> JsonSchema:
    """Build the JSON schema describing the outputs this output schema allows.

    `str`, `DeferredToolRequests` and `BinaryImage` are included when the
    matching `allows_text` / `allows_deferred_tools` / `allows_image` flag is
    set. Structured outputs come from `self.processor` when present, otherwise
    from the processors of `self.toolset`.

    Returns:
        The plain schema of the single allowed type when exactly one
        additional output is allowed and there are no processors; otherwise
        the schema of a `UnionOutputProcessor` built from all of them.
    """
    # A concrete `list` rather than `Sequence`: it is appended to below.
    additional_outputs: list[type] = []
    if self.allows_text:
        additional_outputs.append(str)
    if self.allows_deferred_tools:
        additional_outputs.append(DeferredToolRequests)
    if self.allows_image:
        additional_outputs.append(_messages.BinaryImage)

    # A processor on the schema itself overrides the toolset's processors.
    # NOTE(review): presumably only some schema subclasses define `processor`,
    # hence the lookup with a default instead of plain attribute access.
    processor = getattr(self, 'processor', None)
    if processor is not None:
        processors = {processor.object_def.name: processor}
    elif self.toolset:
        processors = self.toolset.processors
    else:
        processors = None

    # Special case: a single plain output type is not wrapped in a union.
    if len(additional_outputs) == 1 and not processors:
        return TypeAdapter(additional_outputs[0]).json_schema()

    return UnionOutputProcessor(
        outputs=additional_outputs,
        processors=processors,
    ).object_def.json_schema
def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
    """Return the JSON schema of the agent's output.

    Args:
        output_type: Optional output spec; it is forwarded unchanged
            (including `None`) to `_prepare_output_schema`.

    Returns:
        Whatever `json_schema()` returns on the prepared output schema.
    """
    return self._prepare_output_schema(output_type).json_schema()
async def test_text_output_json_schema():
    """An agent created without an output type reports a plain string schema."""
    schema = Agent('test').output_json_schema()
    assert schema == snapshot({'type': 'string'})
async def test_tool_output_json_schema():
    """Tool outputs show up in the schema union under their tool names.

    Covers a single described tool output and two undescribed tool outputs.
    """
    # One tool output carrying a description.
    described_agent = Agent(
        'test',
        output_type=[ToolOutput(bool, name='alice', description='Dreaming...')],
    )
    described_schema = described_agent.output_json_schema()
    assert described_schema == snapshot(
        {
            'type': 'object',
            'properties': {
                'result': {
                    'anyOf': [
                        {
                            'type': 'object',
                            'properties': {
                                'kind': {'type': 'string', 'const': 'alice'},
                                'data': {
                                    'properties': {'response': {'type': 'boolean'}},
                                    'required': ['response'],
                                    'type': 'object',
                                },
                            },
                            'required': ['kind', 'data'],
                            'additionalProperties': False,
                            'title': 'alice',
                            'description': 'Dreaming...',
                        }
                    ]
                }
            },
            'required': ['result'],
            'additionalProperties': False,
        }
    )

    # Two tool outputs of the same type, told apart only by name.
    pair_agent = Agent(
        'test',
        output_type=[ToolOutput(bool, name='alice'), ToolOutput(bool, name='bob')],
    )
    pair_schema = pair_agent.output_json_schema()
    assert pair_schema == snapshot(
        {
            'type': 'object',
            'properties': {
                'result': {
                    'anyOf': [
                        {
                            'type': 'object',
                            'properties': {
                                'kind': {'type': 'string', 'const': 'alice'},
                                'data': {
                                    'properties': {'response': {'type': 'boolean'}},
                                    'required': ['response'],
                                    'type': 'object',
                                },
                            },
                            'required': ['kind', 'data'],
                            'additionalProperties': False,
                            'title': 'alice',
                        },
                        {
                            'type': 'object',
                            'properties': {
                                'kind': {'type': 'string', 'const': 'bob'},
                                'data': {
                                    'properties': {'response': {'type': 'boolean'}},
                                    'required': ['response'],
                                    'type': 'object',
                                },
                            },
                            'required': ['kind', 'data'],
                            'additionalProperties': False,
                            'title': 'bob',
                        },
                    ]
                }
            },
            'required': ['result'],
            'additionalProperties': False,
        }
    )
async def test_custom_output_json_schema():
    """A `StructuredDict` output keeps its name, description and properties in the schema."""
    human_json_schema = {
        'type': 'object',
        'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}},
        'required': ['name', 'age'],
    }
    human_spec = StructuredDict(
        human_json_schema,
        name='Human',
        description='A human with a name and age',
    )
    schema = Agent('test', output_type=human_spec).output_json_schema()
    assert schema == snapshot(
        {
            'type': 'object',
            'properties': {
                'result': {
                    'anyOf': [
                        {
                            'type': 'object',
                            'properties': {
                                'kind': {'type': 'string', 'const': 'Human'},
                                'data': {
                                    'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}},
                                    'required': ['name', 'age'],
                                    'type': 'object',
                                },
                            },
                            'required': ['kind', 'data'],
                            'additionalProperties': False,
                            'title': 'Human',
                            'description': 'A human with a name and age',
                        },
                        {
                            'type': 'object',
                            'properties': {
                                'kind': {'type': 'string', 'const': 'str'},
                                'data': {
                                    'properties': {'response': {'type': 'string'}},
                                    'required': ['response'],
                                    'type': 'object',
                                },
                            },
                            'required': ['kind', 'data'],
                            'additionalProperties': False,
                            'title': 'str',
                        },
                    ]
                }
            },
            'required': ['result'],
            'additionalProperties': False,
        }
    )
'type': 'string'}, + { + 'enum': [ + 'application/pdf', + 'text/plain', + 'text/csv', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + 'text/html', + 'text/markdown', + 'application/msword', + 'application/vnd.ms-excel', + ], + 'type': 'string', + }, + {'type': 'string'}, + ], + 'title': 'Media Type', + }, + 'vendor_metadata': { + 'anyOf': [{'additionalProperties': True, 'type': 'object'}, {'type': 'null'}], + 'default': None, + 'title': 'Vendor Metadata', + }, + 'identifier': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None, 'title': 'Identifier'}, + 'kind': {'const': 'binary', 'default': 'binary', 'title': 'Kind', 'type': 'string'}, + }, + 'required': ['data', 'media_type'], + 'title': 'BinaryImage', + 'type': 'object', + } + ) + agent = Agent('test', output_type=str | BinaryImage) + assert agent.output_json_schema() == snapshot( + { + 'properties': { + 'result': { + 'anyOf': [ + { + 'type': 'object', + 'properties': { + 'kind': {'type': 'string', 'const': 'str'}, + 'data': { + 'properties': {'response': {'type': 'string'}}, + 'required': ['response'], + 'type': 'object', + }, + }, + 'required': ['kind', 'data'], + 'additionalProperties': False, + 'title': 'str', + }, + { + 'type': 'object', + 'properties': { + 'kind': {'type': 'string', 'const': 'BinaryImage'}, + 'data': { + 'properties': { + 'data': {'format': 'binary', 'type': 'string'}, + 'media_type': { + 'anyOf': [ + { + 'enum': [ + 'audio/wav', + 'audio/mpeg', + 'audio/ogg', + 'audio/flac', + 'audio/aiff', + 'audio/aac', + ], + 'type': 'string', + }, + { + 'enum': ['image/jpeg', 'image/png', 'image/gif', 'image/webp'], + 'type': 'string', + }, + { + 'enum': [ + 'application/pdf', + 'text/plain', + 'text/csv', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + 'text/html', + 'text/markdown', + 
async def test_deferred_output_json_schema():
    """Schema for an agent whose output may be text or `DeferredToolRequests`.

    Two tools are registered before the schema is requested, one of which
    requires approval.
    """
    file_agent = Agent('test', output_type=[str, DeferredToolRequests])

    @file_agent.tool
    def update_file(ctx: RunContext, path: str, content: str) -> str:
        return ''

    @file_agent.tool_plain(requires_approval=True)
    def delete_file(path: str) -> str:
        return ''

    schema = file_agent.output_json_schema()
    assert schema == snapshot(
        {
            'type': 'object',
            'properties': {
                'result': {
                    'anyOf': [
                        {
                            'type': 'object',
                            'properties': {
                                'kind': {'type': 'string', 'const': 'str'},
                                'data': {
                                    'properties': {'response': {'type': 'string'}},
                                    'required': ['response'],
                                    'type': 'object',
                                },
                            },
                            'required': ['kind', 'data'],
                            'additionalProperties': False,
                            'title': 'str',
                        },
                        {
                            'type': 'object',
                            'properties': {
                                'kind': {'type': 'string', 'const': 'DeferredToolRequests'},
                                'data': {
                                    'properties': {
                                        'calls': {'items': {'$ref': '#/$defs/ToolCallPart'}, 'type': 'array'},
                                        'approvals': {'items': {'$ref': '#/$defs/ToolCallPart'}, 'type': 'array'},
                                        'metadata': {
                                            'additionalProperties': {'additionalProperties': True, 'type': 'object'},
                                            'type': 'object',
                                        },
                                    },
                                    'type': 'object',
                                },
                            },
                            'required': ['kind', 'data'],
                            'additionalProperties': False,
                            'title': 'DeferredToolRequests',
                        },
                    ]
                }
            },
            'required': ['result'],
            'additionalProperties': False,
            '$defs': {
                'ToolCallPart': {
                    'description': 'A tool call from a model.',
                    'properties': {
                        'tool_name': {'type': 'string'},
                        'args': {
                            'anyOf': [
                                {'type': 'string'},
                                {'additionalProperties': True, 'type': 'object'},
                                {'type': 'null'},
                            ],
                            'default': None,
                        },
                        'tool_call_id': {'type': 'string'},
                        'id': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None},
                        'part_kind': {'const': 'tool-call', 'default': 'tool-call', 'type': 'string'},
                    },
                    'required': ['tool_name'],
                    'title': 'ToolCallPart',
                    'type': 'object',
                }
            },
        }
    )