Skip to content

Commit 64b61d1

Browse files
committed
Use Agent.output_types to construct JSON schema
1 parent a0e8b24 commit 64b61d1

File tree

6 files changed

+241
-178
lines changed

6 files changed

+241
-178
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from abc import ABC, abstractmethod
77
from collections.abc import Awaitable, Callable, Sequence
88
from dataclasses import dataclass, field
9-
from functools import cached_property
109
from typing import TYPE_CHECKING, Any, Generic, Literal, cast, overload
1110

1211
from pydantic import Json, TypeAdapter, ValidationError
@@ -16,7 +15,6 @@
1615
from pydantic_ai._instrumentation import InstrumentationNames
1716

1817
from . import _function_schema, _utils, messages as _messages
19-
from ._json_schema import JsonSchema
2018
from ._run_context import AgentDepsT, RunContext
2119
from .exceptions import ModelRetry, ToolRetryError, UserError
2220
from .output import (
@@ -228,10 +226,6 @@ def mode(self) -> OutputMode:
228226
def allows_text(self) -> bool:
229227
return self.text_processor is not None
230228

231-
@cached_property
232-
def json_schema(self) -> JsonSchema:
233-
raise NotImplementedError()
234-
235229
@classmethod
236230
def build( # noqa: C901
237231
cls,
@@ -385,56 +379,6 @@ def _build_processor(
385379

386380
return UnionOutputProcessor(outputs=outputs, strict=strict, name=name, description=description)
387381

388-
def build_json_schema(self) -> JsonSchema: # noqa: C901
389-
# allow any output with {'type': 'string'} if no constraints
390-
if not any([self.allows_deferred_tools, self.allows_image, self.object_def, self.toolset]):
391-
return TypeAdapter(str).json_schema()
392-
393-
json_schemas: list[ObjectJsonSchema] = []
394-
395-
processor = getattr(self, 'processor', None)
396-
if isinstance(processor, ObjectOutputProcessor):
397-
json_schema = processor.object_def.json_schema
398-
if k := processor.outer_typed_dict_key:
399-
json_schema = json_schema['properties'][k]
400-
json_schemas.append(json_schema)
401-
402-
elif self.toolset:
403-
if self.allows_text:
404-
json_schema = TypeAdapter(str).json_schema()
405-
json_schemas.append(json_schema)
406-
for tool_processor in self.toolset.processors.values():
407-
json_schema = tool_processor.object_def.json_schema
408-
if k := tool_processor.outer_typed_dict_key:
409-
json_schema = json_schema['properties'][k]
410-
if json_schema not in json_schemas:
411-
json_schemas.append(json_schema)
412-
413-
elif self.allows_text:
414-
json_schema = TypeAdapter(str).json_schema()
415-
json_schemas.append(json_schema)
416-
417-
if self.allows_deferred_tools:
418-
json_schema = TypeAdapter(DeferredToolRequests).json_schema(mode='serialization')
419-
if json_schema not in json_schemas:
420-
json_schemas.append(json_schema)
421-
422-
if self.allows_image:
423-
json_schema = TypeAdapter(_messages.BinaryImage).json_schema()
424-
json_schema = {k: v for k, v in json_schema['properties'].items() if k in ['data', 'media_type']}
425-
if json_schema not in json_schemas:
426-
json_schemas.append(json_schema)
427-
428-
if len(json_schemas) == 1:
429-
return json_schemas[0]
430-
431-
json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas)
432-
json_schema: JsonSchema = {'anyOf': json_schemas}
433-
if all_defs:
434-
json_schema['$defs'] = all_defs
435-
436-
return json_schema
437-
438382

439383
@dataclass(init=False)
440384
class AutoOutputSchema(OutputSchema[OutputDataT]):
@@ -463,10 +407,6 @@ def __init__(
463407
def mode(self) -> OutputMode:
464408
return 'auto'
465409

466-
@cached_property
467-
def json_schema(self) -> JsonSchema:
468-
return self.build_json_schema()
469-
470410

471411
@dataclass(init=False)
472412
class TextOutputSchema(OutputSchema[OutputDataT]):
@@ -487,10 +427,6 @@ def __init__(
487427
def mode(self) -> OutputMode:
488428
return 'text'
489429

490-
@cached_property
491-
def json_schema(self) -> JsonSchema:
492-
return self.build_json_schema()
493-
494430

495431
class ImageOutputSchema(OutputSchema[OutputDataT]):
496432
def __init__(self, *, allows_deferred_tools: bool):
@@ -500,10 +436,6 @@ def __init__(self, *, allows_deferred_tools: bool):
500436
def mode(self) -> OutputMode:
501437
return 'image'
502438

503-
@cached_property
504-
def json_schema(self) -> JsonSchema:
505-
return self.build_json_schema()
506-
507439

508440
@dataclass(init=False)
509441
class StructuredTextOutputSchema(OutputSchema[OutputDataT], ABC):
@@ -520,10 +452,6 @@ def __init__(
520452
)
521453
self.processor = processor
522454

523-
@cached_property
524-
def json_schema(self) -> JsonSchema:
525-
return self.build_json_schema()
526-
527455

528456
class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
529457
@property
@@ -590,10 +518,6 @@ def __init__(
590518
def mode(self) -> OutputMode:
591519
return 'tool'
592520

593-
@cached_property
594-
def json_schema(self) -> JsonSchema:
595-
return self.build_json_schema()
596-
597521

598522
class BaseOutputProcessor(ABC, Generic[OutputDataT]):
599523
@abstractmethod

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
from typing import TYPE_CHECKING, Any, ClassVar, overload
1212

1313
from opentelemetry.trace import NoOpTracer, use_span
14+
from pydantic import TypeAdapter
1415
from pydantic.json_schema import GenerateJsonSchema
1516
from typing_extensions import Self, TypeVar, deprecated
1617

1718
from pydantic_ai._instrumentation import DEFAULT_INSTRUMENTATION_VERSION, InstrumentationNames
1819

1920
from .. import (
2021
_agent_graph,
22+
_function_schema,
2123
_output,
2224
_system_prompt,
2325
_utils,
@@ -956,10 +958,49 @@ def decorator(
956958
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
957959
return func
958960

959-
def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
961+
def output_json_schema(self, output_type: OutputSpec[OutputDataT] | None = None) -> JsonSchema:
960962
"""The output JSON schema."""
961-
output_schema = self._prepare_output_schema(output_type)
962-
return output_schema.json_schema
963+
if output_type is None:
964+
output_type = self.output_type
965+
966+
# call this first to force output_type to an iterable
967+
output_type = list(_output._flatten_output_spec(output_type))
968+
969+
# flatten special outputs
970+
for i, _ in enumerate(output_type):
971+
if isinstance(_, _output.NativeOutput):
972+
output_type[i] = _output._flatten_output_spec(_.outputs)
973+
if isinstance(_, _output.PromptedOutput):
974+
output_type[i] = _output._flatten_output_spec(_.outputs)
975+
if isinstance(_, _output.ToolOutput):
976+
output_type[i] = _output._flatten_output_spec(_.output)
977+
978+
# final flattening
979+
output_type = _output._flatten_output_spec(output_type)
980+
981+
json_schemas: list[JsonSchema] = []
982+
for _ in output_type:
983+
if inspect.isfunction(_) or inspect.ismethod(_):
984+
function_schema = _function_schema.function_schema(_, GenerateToolJsonSchema)
985+
json_schema = function_schema.json_schema
986+
json_schema['description'] = function_schema.description
987+
elif isinstance(_, _messages.BinaryImage):
988+
json_schema = TypeAdapter(_).json_schema(mode='serialization')
989+
json_schema = {k: v for k, v in json_schema['properties'].items() if k in ['data', 'media_type']}
990+
else:
991+
json_schema = TypeAdapter(_).json_schema(mode='serialization')
992+
993+
if json_schema not in json_schemas:
994+
json_schemas.append(json_schema)
995+
996+
if len(json_schemas) == 1:
997+
return json_schemas[0]
998+
else:
999+
json_schemas, all_defs = _utils.merge_json_schema_defs(json_schemas)
1000+
json_schema: JsonSchema = {'anyOf': json_schemas}
1001+
if all_defs:
1002+
json_schema['$defs'] = all_defs
1003+
return json_schema
9631004

9641005
@overload
9651006
def output_validator(

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
124124
raise NotImplementedError
125125

126126
@abstractmethod
127-
def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
127+
def output_json_schema(self, output_type: OutputSpec[OutputDataT] | None = None) -> JsonSchema:
128128
"""The output JSON schema."""
129129
raise NotImplementedError
130130

pydantic_ai_slim/pydantic_ai/agent/wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
6868
async def __aexit__(self, *args: Any) -> bool | None:
6969
return await self.wrapped.__aexit__(*args)
7070

71-
def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
71+
def output_json_schema(self, output_type: OutputSpec[OutputDataT] | None = None) -> JsonSchema:
7272
return self.wrapped.output_json_schema(output_type=output_type)
7373

7474
@overload

tests/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5251,7 +5251,7 @@ def foo() -> str:
52515251
assert wrapper_agent.output_json_schema() == snapshot(
52525252
{
52535253
'type': 'object',
5254-
'properties': {'a': {'type': 'integer'}, 'b': {'type': 'string'}},
5254+
'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'string'}},
52555255
'title': 'Foo',
52565256
'required': ['a', 'b'],
52575257
}

0 commit comments

Comments
 (0)