Skip to content

Commit d9ae4ae

Browse files
committed
tweaks and coverage
1 parent ae410f1 commit d9ae4ae

File tree

10 files changed

+51
-57
lines changed

10 files changed

+51
-57
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
140140
end_strategy: EndStrategy
141141
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
142142

143-
output_schema: _output.BaseOutputSchema[OutputDataT]
143+
output_schema: _output.OutputSchema[OutputDataT]
144144
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
145145

146146
history_processors: Sequence[HistoryProcessor[DepsT]]

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -211,28 +211,21 @@ async def validate(
211211

212212

213213
@dataclass(kw_only=True)
214-
class BaseOutputSchema(ABC, Generic[OutputDataT]):
214+
class OutputSchema(ABC, Generic[OutputDataT]):
215215
text_processor: BaseOutputProcessor[OutputDataT] | None = None
216216
toolset: OutputToolset[Any] | None = None
217217
object_def: OutputObjectDefinition | None = None
218218
allows_deferred_tools: bool = False
219219
allows_image: bool = False
220220

221221
@property
222-
def mode(self) -> OutputMode | None:
222+
def mode(self) -> OutputMode:
223223
raise NotImplementedError()
224224

225225
@property
226226
def allows_text(self) -> bool:
227227
return self.text_processor is not None
228228

229-
230-
@dataclass(init=False)
231-
class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
232-
"""Model the final output from an agent run."""
233-
234-
# TODO (DouweM): Rename/merge this, BaseOutputSchema, and OutputSchemaWithoutMode
235-
236229
@classmethod
237230
def build( # noqa: C901
238231
cls,
@@ -241,7 +234,7 @@ def build( # noqa: C901
241234
name: str | None = None,
242235
description: str | None = None,
243236
strict: bool | None = None,
244-
) -> BaseOutputSchema[OutputDataT]:
237+
) -> OutputSchema[OutputDataT]:
245238
"""Build an OutputSchema dataclass from an output type."""
246239
outputs = _flatten_output_spec(output_spec)
247240

@@ -359,7 +352,7 @@ def build( # noqa: C901
359352
)
360353

361354
if len(other_outputs) > 0:
362-
return OutputSchemaWithoutMode(
355+
return AutoOutputSchema(
363356
processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
364357
toolset=toolset,
365358
allows_deferred_tools=allows_deferred_tools,
@@ -386,7 +379,7 @@ def _build_processor(
386379

387380

388381
@dataclass(init=False)
389-
class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
382+
class AutoOutputSchema(OutputSchema[OutputDataT]):
390383
processor: BaseObjectOutputProcessor[OutputDataT]
391384

392385
def __init__(
@@ -400,18 +393,17 @@ def __init__(
400393
# At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time,
401394
# but we cover ourselves just in case we end up using the tool output mode.
402395
super().__init__(
403-
allows_deferred_tools=allows_deferred_tools,
404396
toolset=toolset,
405397
object_def=processor.object_def,
406398
text_processor=processor,
399+
allows_deferred_tools=allows_deferred_tools,
407400
allows_image=allows_image,
408401
)
409402
self.processor = processor
410403

411404
@property
412-
def mode(self) -> OutputMode | None:
413-
# TODO (DouweM): Could this be a field?
414-
return None
405+
def mode(self) -> OutputMode:
406+
return 'auto'
415407

416408

417409
@dataclass(init=False)
@@ -430,7 +422,7 @@ def __init__(
430422
)
431423

432424
@property
433-
def mode(self) -> OutputMode | None:
425+
def mode(self) -> OutputMode:
434426
return 'text'
435427

436428

@@ -439,7 +431,7 @@ def __init__(self, *, allows_deferred_tools: bool):
439431
super().__init__(allows_deferred_tools=allows_deferred_tools, allows_image=True)
440432

441433
@property
442-
def mode(self) -> OutputMode | None:
434+
def mode(self) -> OutputMode:
443435
return 'image'
444436

445437

@@ -461,7 +453,7 @@ def __init__(
461453

462454
class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
463455
@property
464-
def mode(self) -> OutputMode | None:
456+
def mode(self) -> OutputMode:
465457
return 'native'
466458

467459

@@ -485,7 +477,7 @@ def __init__(
485477
self.template = template
486478

487479
@property
488-
def mode(self) -> OutputMode | None:
480+
def mode(self) -> OutputMode:
489481
return 'prompted'
490482

491483
@classmethod
@@ -528,7 +520,7 @@ def __init__(
528520
)
529521

530522
@property
531-
def mode(self) -> OutputMode | None:
523+
def mode(self) -> OutputMode:
532524
return 'tool'
533525

534526

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]):
132132
_instrument_default: ClassVar[InstrumentationSettings | bool] = False
133133

134134
_deps_type: type[AgentDepsT] = dataclasses.field(repr=False)
135-
_output_schema: _output.BaseOutputSchema[OutputDataT] = dataclasses.field(repr=False)
135+
_output_schema: _output.OutputSchema[OutputDataT] = dataclasses.field(repr=False)
136136
_output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] = dataclasses.field(repr=False)
137137
_instructions: list[str | _system_prompt.SystemPromptFunc[AgentDepsT]] = dataclasses.field(repr=False)
138138
_system_prompts: tuple[str, ...] = dataclasses.field(repr=False)
@@ -1409,14 +1409,14 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
14091409
return toolsets
14101410

14111411
@overload
1412-
def _prepare_output_schema(self, output_type: None) -> _output.BaseOutputSchema[OutputDataT]: ...
1412+
def _prepare_output_schema(self, output_type: None) -> _output.OutputSchema[OutputDataT]: ...
14131413

14141414
@overload
14151415
def _prepare_output_schema(
14161416
self, output_type: OutputSpec[RunOutputDataT]
1417-
) -> _output.BaseOutputSchema[RunOutputDataT]: ...
1417+
) -> _output.OutputSchema[RunOutputDataT]: ...
14181418

1419-
def _prepare_output_schema(self, output_type: OutputSpec[Any] | None) -> _output.BaseOutputSchema[Any]:
1419+
def _prepare_output_schema(self, output_type: OutputSpec[Any] | None) -> _output.OutputSchema[Any]:
14201420
if output_type is not None:
14211421
if self._output_validators:
14221422
raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators')
@@ -1494,15 +1494,15 @@ async def run_mcp_servers(
14941494

14951495
@dataclasses.dataclass(init=False)
14961496
class _AgentFunctionToolset(FunctionToolset[AgentDepsT]):
1497-
output_schema: _output.BaseOutputSchema[Any]
1497+
output_schema: _output.OutputSchema[Any]
14981498

14991499
def __init__(
15001500
self,
15011501
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [],
15021502
*,
15031503
max_retries: int = 1,
15041504
id: str | None = None,
1505-
output_schema: _output.BaseOutputSchema[Any],
1505+
output_schema: _output.OutputSchema[Any],
15061506
):
15071507
self.output_schema = output_schema
15081508
super().__init__(tools, max_retries=max_retries, id=id)

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ class ModelRequestParameters:
306306
function_tools: list[ToolDefinition] = field(default_factory=list)
307307
builtin_tools: list[AbstractBuiltinTool] = field(default_factory=list)
308308

309-
output_mode: OutputMode | None = 'text' # TODO (DouweM): None or new `'auto'` value? AutoOutputSchema is not bad.
309+
output_mode: OutputMode = 'text'
310310
output_object: OutputObjectDefinition | None = None
311311
output_tools: list[ToolDefinition] = field(default_factory=list)
312312
prompted_output_template: str | None = None
@@ -426,8 +426,7 @@ def prepare_request(
426426
builtin_tools=list({tool.unique_id: tool for tool in builtin_tools}.values()),
427427
)
428428

429-
if not model_request_parameters.output_mode:
430-
assert model_request_parameters.output_tools or model_request_parameters.output_object
429+
if model_request_parameters.output_mode == 'auto':
431430
output_mode = self.profile.default_structured_output_mode
432431
model_request_parameters = replace(
433432
model_request_parameters,
@@ -436,31 +435,33 @@ def prepare_request(
436435
)
437436

438437
if model_request_parameters.output_mode in ('native', 'prompted'):
439-
if not model_request_parameters.output_object:
440-
raise UserError( # pragma: no cover
441-
'An `output_object` is required when using `NativeOutput` or `PromptedOutput`.'
442-
)
443-
444-
if model_request_parameters.output_mode == 'native' and not self.profile.supports_json_schema_output:
445-
raise UserError('Native structured output is not supported by this model.')
438+
assert model_request_parameters.output_object
446439

447440
if model_request_parameters.output_tools:
448441
model_request_parameters = replace(model_request_parameters, output_tools=[])
449-
450-
if not model_request_parameters.prompted_output_template:
451-
model_request_parameters = replace(
452-
model_request_parameters, prompted_output_template=self.profile.prompted_output_template
453-
)
454442
else:
455-
if model_request_parameters.output_mode == 'tool':
456-
if not model_request_parameters.output_tools and not model_request_parameters.function_tools:
457-
raise UserError('An `output_tools` list is required when using `ToolOutput`.') # pragma: no cover
443+
if model_request_parameters.output_object:
444+
model_request_parameters = replace(model_request_parameters, output_object=None)
445+
446+
match model_request_parameters.output_mode:
447+
case 'native':
448+
if not self.profile.supports_json_schema_output:
449+
raise UserError('Native structured output is not supported by this model.')
450+
451+
if model_request_parameters.prompted_output_template:
452+
model_request_parameters = replace(model_request_parameters, prompted_output_template=None)
453+
case 'prompted':
454+
if not model_request_parameters.prompted_output_template:
455+
model_request_parameters = replace(
456+
model_request_parameters, prompted_output_template=self.profile.prompted_output_template
457+
)
458+
case 'tool':
459+
assert model_request_parameters.output_tools or model_request_parameters.function_tools
458460

459461
if not self.profile.supports_tools:
460462
raise UserError('Tool output is not supported by this model.')
461-
462-
if model_request_parameters.output_object:
463-
model_request_parameters = replace(model_request_parameters, output_object=None)
463+
case _:
464+
pass
464465

465466
if model_request_parameters.allow_image_output and not self.profile.supports_image_output:
466467
raise UserError('Image output is not supported by this model.')

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,15 +243,14 @@ async def request_stream(
243243
def prepare_request(
244244
self, model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters
245245
) -> tuple[ModelSettings | None, ModelRequestParameters]:
246-
# TODO (DouweM): Dedupe with super
247246
settings = merge_model_settings(self.settings, model_settings)
248247
if (
249248
model_request_parameters.output_tools
250249
and settings
251250
and (thinking := settings.get('anthropic_thinking'))
252251
and thinking.get('type') == 'enabled'
253252
):
254-
if model_request_parameters.output_mode is None:
253+
if model_request_parameters.output_mode == 'auto':
255254
model_request_parameters = replace(model_request_parameters, output_mode='prompted')
256255
elif model_request_parameters.output_mode == 'tool' and not model_request_parameters.allow_text_output:
257256
# This would result in `tool_choice=required`, which Anthropic does not support with thinking.

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def prepare_request(
228228
self, model_settings: ModelSettings | None, model_request_parameters: ModelRequestParameters
229229
) -> tuple[ModelSettings | None, ModelRequestParameters]:
230230
if model_request_parameters.builtin_tools and model_request_parameters.output_tools:
231-
if model_request_parameters.output_mode is None:
231+
if model_request_parameters.output_mode == 'auto':
232232
model_request_parameters = replace(model_request_parameters, output_mode='prompted')
233233
else:
234234
raise UserError(

pydantic_ai_slim/pydantic_ai/models/outlines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ async def _process_streamed_response(
528528

529529
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
530530
"""Customize the model request parameters for the model."""
531-
if model_request_parameters.output_mode in (None, 'native'):
531+
if model_request_parameters.output_mode in ('auto', 'native'):
532532
# This way the JSON schema will be included in the instructions.
533533
return replace(model_request_parameters, output_mode='prompted')
534534
else:

pydantic_ai_slim/pydantic_ai/output.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,11 @@
3737
OutputDataT = TypeVar('OutputDataT', default=str, covariant=True)
3838
"""Covariant type variable for the output data type of a run."""
3939

40-
OutputMode = Literal['text', 'tool', 'native', 'prompted', 'tool_or_text', 'image']
40+
OutputMode = Literal['text', 'tool', 'native', 'prompted', 'tool_or_text', 'image', 'auto']
4141
"""All output modes.
4242
43-
`tool_or_text` is deprecated and no longer in use.
43+
- `tool_or_text` is deprecated and no longer in use.
44+
- `auto` means the model will automatically choose a structured output mode based on the model's `ModelProfile.default_structured_output_mode`.
4445
"""
4546
StructuredOutputMode = Literal['tool', 'native', 'prompted']
4647
"""Output modes that can be used for structured output. Used by ModelProfile.default_structured_output_mode"""

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
from . import _utils, exceptions, messages as _messages, models
1313
from ._output import (
14-
BaseOutputSchema,
1514
OutputDataT_inv,
15+
OutputSchema,
1616
OutputValidator,
1717
OutputValidatorFunc,
1818
TextOutputSchema,
@@ -46,7 +46,7 @@
4646
@dataclass(kw_only=True)
4747
class AgentStream(Generic[AgentDepsT, OutputDataT]):
4848
_raw_stream_response: models.StreamedResponse
49-
_output_schema: BaseOutputSchema[OutputDataT]
49+
_output_schema: OutputSchema[OutputDataT]
5050
_model_request_parameters: models.ModelRequestParameters
5151
_output_validators: list[OutputValidator[AgentDepsT, OutputDataT]]
5252
_run_ctx: RunContext[AgentDepsT]

tests/test_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3791,6 +3791,7 @@ def get_image() -> BinaryContent:
37913791
BinaryContent(
37923792
data=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82',
37933793
media_type='image/png',
3794+
_identifier='image_id_1',
37943795
),
37953796
],
37963797
timestamp=IsNow(tz=timezone.utc),

0 commit comments

Comments
 (0)