Skip to content
Draft
46 changes: 44 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_output.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations as _annotations

import copy
import inspect
import json
import re
Expand All @@ -15,6 +16,7 @@
from pydantic_ai._instrumentation import InstrumentationNames

from . import _function_schema, _utils, messages as _messages
from ._json_schema import JsonSchema
from ._run_context import AgentDepsT, RunContext
from .exceptions import ModelRetry, ToolRetryError, UserError
from .output import (
Expand Down Expand Up @@ -226,6 +228,10 @@ def mode(self) -> OutputMode:
def allows_text(self) -> bool:
return self.text_processor is not None

@abstractmethod
def dump(self) -> JsonSchema:
raise NotImplementedError()

@classmethod
def build( # noqa: C901
cls,
Expand Down Expand Up @@ -405,6 +411,18 @@ def __init__(
def mode(self) -> OutputMode:
return 'auto'

def dump(self) -> JsonSchema:
if self.toolset:
processors: list[ObjectOutputProcessor[OutputDataT]] = []
for tool_def in self.toolset._tool_defs: # pyright: ignore [reportPrivateUsage]
processor = copy.copy(self.toolset.processors[tool_def.name])
processor.object_def.name = tool_def.name
processor.object_def.description = tool_def.description
processors.append(processor)
return UnionOutputProcessor(processors).object_def.json_schema

return self.processor.object_def.json_schema


@dataclass(init=False)
class TextOutputSchema(OutputSchema[OutputDataT]):
Expand All @@ -425,6 +443,9 @@ def __init__(
def mode(self) -> OutputMode:
return 'text'

def dump(self) -> JsonSchema:
return {'type': 'string'}


class ImageOutputSchema(OutputSchema[OutputDataT]):
def __init__(self, *, allows_deferred_tools: bool):
Expand All @@ -434,6 +455,9 @@ def __init__(self, *, allows_deferred_tools: bool):
def mode(self) -> OutputMode:
return 'image'

def dump(self) -> JsonSchema:
raise NotImplementedError()


@dataclass(init=False)
class StructuredTextOutputSchema(OutputSchema[OutputDataT], ABC):
Expand All @@ -450,6 +474,9 @@ def __init__(
)
self.processor = processor

def dump(self) -> JsonSchema:
return self.processor.object_def.json_schema


class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
@property
Expand Down Expand Up @@ -523,6 +550,18 @@ def __init__(
def mode(self) -> OutputMode:
return 'tool'

def dump(self) -> JsonSchema:
if self.toolset is None:
# need to check expected behavior
raise NotImplementedError()
processors: list[ObjectOutputProcessor[OutputDataT]] = []
for tool_def in self.toolset._tool_defs: # pyright: ignore [reportPrivateUsage]
processor = copy.copy(self.toolset.processors[tool_def.name])
processor.object_def.name = tool_def.name
processor.object_def.description = tool_def.description
processors.append(processor)
return UnionOutputProcessor(processors).object_def.json_schema


class BaseOutputProcessor(ABC, Generic[OutputDataT]):
@abstractmethod
Expand Down Expand Up @@ -714,7 +753,7 @@ class UnionOutputProcessor(BaseObjectOutputProcessor[OutputDataT]):

def __init__(
self,
outputs: Sequence[OutputTypeOrFunction[OutputDataT]],
outputs: Sequence[OutputTypeOrFunction[OutputDataT] | ObjectOutputProcessor[OutputDataT]],
*,
name: str | None = None,
description: str | None = None,
Expand All @@ -725,7 +764,10 @@ def __init__(
json_schemas: list[ObjectJsonSchema] = []
self._processors = {}
for output in outputs:
processor = ObjectOutputProcessor(output=output, strict=strict)
if isinstance(output, ObjectOutputProcessor):
processor = output
else:
processor = ObjectOutputProcessor(output=output, strict=strict)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we need to change this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modify UnionOutputProcessor so that toolset output processors can be unioned with additional output types (like str, BinaryImage, DeferredToolRequests). Maybe I shouldn't be using toolset.processors?

object_def = processor.object_def

object_key = object_def.name or output.__name__
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
UserPromptNode,
capture_run_messages,
)
from .._json_schema import JsonSchema
from .._output import OutputToolset
from .._tool_manager import ToolManager
from ..builtin_tools import AbstractBuiltinTool
Expand Down Expand Up @@ -947,6 +948,11 @@ def decorator(
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
return func

def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
"""The output JSON schema."""
output_schema = self._prepare_output_schema(output_type)
return output_schema.dump()

@overload
def output_validator(
self, func: Callable[[RunContext[AgentDepsT], OutputDataT], OutputDataT], /
Expand Down
6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
result,
usage as _usage,
)
from .._json_schema import JsonSchema
from .._tool_manager import ToolManager
from ..builtin_tools import AbstractBuiltinTool
from ..output import OutputDataT, OutputSpec
Expand Down Expand Up @@ -122,6 +123,11 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
"""
raise NotImplementedError

@abstractmethod
def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
"""The output JSON schema."""
raise NotImplementedError

@overload
async def run(
self,
Expand Down
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
models,
usage as _usage,
)
from .._json_schema import JsonSchema
from ..builtin_tools import AbstractBuiltinTool
from ..output import OutputDataT, OutputSpec
from ..run import AgentRun
Expand Down Expand Up @@ -67,6 +68,9 @@ async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
async def __aexit__(self, *args: Any) -> bool | None:
return await self.wrapped.__aexit__(*args)

def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
return self.wrapped.output_json_schema(output_type=output_type)

@overload
def iter(
self,
Expand Down
Loading
Loading