-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Add output_json_schema method to Agent class #3454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 12 commits
0ee01f1
80dd6e3
7127de6
291fa1f
d1b2399
4454583
fdc8820
93de392
79253f1
ba9433f
09184e2
0d58762
0043115
123eabc
66f26b5
74ac9b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| from __future__ import annotations as _annotations | ||
|
|
||
| import copy | ||
| import inspect | ||
| import json | ||
| import re | ||
|
|
@@ -15,6 +16,7 @@ | |
| from pydantic_ai._instrumentation import InstrumentationNames | ||
|
|
||
| from . import _function_schema, _utils, messages as _messages | ||
| from ._json_schema import JsonSchema | ||
| from ._run_context import AgentDepsT, RunContext | ||
| from .exceptions import ModelRetry, ToolRetryError, UserError | ||
| from .output import ( | ||
|
|
@@ -226,6 +228,10 @@ def mode(self) -> OutputMode: | |
| def allows_text(self) -> bool: | ||
| return self.text_processor is not None | ||
|
|
||
| @abstractmethod | ||
| def dump(self) -> JsonSchema: | ||
| raise NotImplementedError() | ||
|
|
||
| @classmethod | ||
| def build( # noqa: C901 | ||
| cls, | ||
|
|
@@ -405,6 +411,18 @@ def __init__( | |
| def mode(self) -> OutputMode: | ||
| return 'auto' | ||
|
|
||
| def dump(self) -> JsonSchema: | ||
| if self.toolset: | ||
| processors: list[ObjectOutputProcessor[OutputDataT]] = [] | ||
| for tool_def in self.toolset._tool_defs: # pyright: ignore [reportPrivateUsage] | ||
| processor = copy.copy(self.toolset.processors[tool_def.name]) | ||
| processor.object_def.name = tool_def.name | ||
| processor.object_def.description = tool_def.description | ||
| processors.append(processor) | ||
| return UnionOutputProcessor(processors).object_def.json_schema | ||
|
|
||
| return self.processor.object_def.json_schema | ||
g-eoj marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @dataclass(init=False) | ||
| class TextOutputSchema(OutputSchema[OutputDataT]): | ||
|
|
@@ -425,6 +443,9 @@ def __init__( | |
| def mode(self) -> OutputMode: | ||
| return 'text' | ||
|
|
||
| def dump(self) -> JsonSchema: | ||
| return {'type': 'string'} | ||
|
|
||
|
|
||
| class ImageOutputSchema(OutputSchema[OutputDataT]): | ||
| def __init__(self, *, allows_deferred_tools: bool): | ||
|
|
@@ -434,6 +455,9 @@ def __init__(self, *, allows_deferred_tools: bool): | |
| def mode(self) -> OutputMode: | ||
| return 'image' | ||
|
|
||
| def dump(self) -> JsonSchema: | ||
| raise NotImplementedError() | ||
g-eoj marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @dataclass(init=False) | ||
| class StructuredTextOutputSchema(OutputSchema[OutputDataT], ABC): | ||
|
|
@@ -450,6 +474,9 @@ def __init__( | |
| ) | ||
| self.processor = processor | ||
|
|
||
| def dump(self) -> JsonSchema: | ||
| return self.processor.object_def.json_schema | ||
|
|
||
|
|
||
| class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]): | ||
| @property | ||
|
|
@@ -523,6 +550,18 @@ def __init__( | |
| def mode(self) -> OutputMode: | ||
| return 'tool' | ||
|
|
||
| def dump(self) -> JsonSchema: | ||
| if self.toolset is None: | ||
| # need to check expected behavior | ||
| raise NotImplementedError() | ||
| processors: list[ObjectOutputProcessor[OutputDataT]] = [] | ||
| for tool_def in self.toolset._tool_defs: # pyright: ignore [reportPrivateUsage] | ||
| processor = copy.copy(self.toolset.processors[tool_def.name]) | ||
| processor.object_def.name = tool_def.name | ||
| processor.object_def.description = tool_def.description | ||
| processors.append(processor) | ||
| return UnionOutputProcessor(processors).object_def.json_schema | ||
g-eoj marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| class BaseOutputProcessor(ABC, Generic[OutputDataT]): | ||
| @abstractmethod | ||
|
|
@@ -714,7 +753,7 @@ class UnionOutputProcessor(BaseObjectOutputProcessor[OutputDataT]): | |
|
|
||
| def __init__( | ||
| self, | ||
| outputs: Sequence[OutputTypeOrFunction[OutputDataT]], | ||
| outputs: Sequence[OutputTypeOrFunction[OutputDataT] | ObjectOutputProcessor[OutputDataT]], | ||
| *, | ||
| name: str | None = None, | ||
| description: str | None = None, | ||
|
|
@@ -725,7 +764,10 @@ def __init__( | |
| json_schemas: list[ObjectJsonSchema] = [] | ||
| self._processors = {} | ||
| for output in outputs: | ||
| processor = ObjectOutputProcessor(output=output, strict=strict) | ||
| if isinstance(output, ObjectOutputProcessor): | ||
| processor = output | ||
| else: | ||
| processor = ObjectOutputProcessor(output=output, strict=strict) | ||
|
||
| object_def = processor.object_def | ||
|
|
||
| object_key = object_def.name or output.__name__ | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.