Skip to content

Commit a4a81cc

Browse files
committed
Support Native and Prompted output modes when using FallbackModel
1 parent cedee4a commit a4a81cc

File tree

17 files changed

+200
-148
lines changed

17 files changed

+200
-148
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
140140
end_strategy: EndStrategy
141141
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
142142

143-
output_schema: _output.OutputSchema[OutputDataT]
143+
output_schema: _output.BaseOutputSchema[OutputDataT]
144144
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
145145

146146
history_processors: Sequence[HistoryProcessor[DepsT]]
@@ -374,9 +374,10 @@ async def _prepare_request_parameters(
374374
) -> models.ModelRequestParameters:
375375
"""Build tools and create an agent model."""
376376
output_schema = ctx.deps.output_schema
377-
output_object = None
378-
if isinstance(output_schema, _output.NativeOutputSchema):
379-
output_object = output_schema.object_def
377+
378+
prompted_output_template = (
379+
output_schema.template if isinstance(output_schema, _output.PromptedOutputSchema) else None
380+
)
380381

381382
function_tools: list[ToolDefinition] = []
382383
output_tools: list[ToolDefinition] = []
@@ -391,7 +392,8 @@ async def _prepare_request_parameters(
391392
builtin_tools=ctx.deps.builtin_tools,
392393
output_mode=output_schema.mode,
393394
output_tools=output_tools,
394-
output_object=output_object,
395+
output_object=output_schema.object_def,
396+
prompted_output_template=prompted_output_template,
395397
allow_text_output=output_schema.allows_text,
396398
allow_image_output=output_schema.allows_image,
397399
)
@@ -489,7 +491,6 @@ async def _prepare_request(
489491
message_history = _clean_message_history(message_history)
490492

491493
model_request_parameters = await _prepare_request_parameters(ctx)
492-
model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
493494

494495
model_settings = ctx.deps.model_settings
495496
usage = ctx.state.usage

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 22 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from pydantic import Json, TypeAdapter, ValidationError
1212
from pydantic_core import SchemaValidator, to_json
13-
from typing_extensions import Self, TypedDict, TypeVar, assert_never
13+
from typing_extensions import Self, TypedDict, TypeVar
1414

1515
from pydantic_ai._instrumentation import InstrumentationNames
1616

@@ -215,11 +215,12 @@ async def validate(
215215
class BaseOutputSchema(ABC, Generic[OutputDataT]):
216216
text_processor: BaseOutputProcessor[OutputDataT] | None = None
217217
toolset: OutputToolset[Any] | None = None
218+
object_def: OutputObjectDefinition | None = None
218219
allows_deferred_tools: bool = False
219220
allows_image: bool = False
220221

221-
@abstractmethod
222-
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
222+
@property
223+
def mode(self) -> OutputMode | None:
223224
raise NotImplementedError()
224225

225226
@property
@@ -231,6 +232,8 @@ def allows_text(self) -> bool:
231232
class OutputSchema(BaseOutputSchema[OutputDataT], ABC):
232233
"""Model the final output from an agent run."""
233234

235+
# TODO (DouweM): Rename/merge this, BaseOutputSchema, and OutputSchemaWithoutMode
236+
234237
@classmethod
235238
@overload
236239
def build(
@@ -260,7 +263,6 @@ def build( # noqa: C901
260263
cls,
261264
output_spec: OutputSpec[OutputDataT],
262265
*,
263-
default_mode: StructuredOutputMode | None = None,
264266
name: str | None = None,
265267
description: str | None = None,
266268
strict: bool | None = None,
@@ -382,15 +384,12 @@ def build( # noqa: C901
382384
)
383385

384386
if len(other_outputs) > 0:
385-
schema = OutputSchemaWithoutMode(
387+
return OutputSchemaWithoutMode(
386388
processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict),
387389
toolset=toolset,
388390
allows_deferred_tools=allows_deferred_tools,
389391
allows_image=allows_image,
390392
)
391-
if default_mode:
392-
schema = schema.with_default_mode(default_mode)
393-
return schema
394393

395394
if allows_image:
396395
return ImageOutputSchema(allows_deferred_tools=allows_deferred_tools)
@@ -410,19 +409,12 @@ def _build_processor(
410409

411410
return UnionOutputProcessor(outputs=outputs, strict=strict, name=name, description=description)
412411

413-
@property
414-
@abstractmethod
415-
def mode(self) -> OutputMode:
416-
raise NotImplementedError()
417-
418412
def raise_if_unsupported(self, profile: ModelProfile) -> None:
419413
"""Raise an error if the mode is not supported by this model."""
414+
# TODO (DouweM): Remove method?
420415
if self.allows_image and not profile.supports_image_output:
421416
raise UserError('Image output is not supported by this model.')
422417

423-
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
424-
return self
425-
426418

427419
@dataclass(init=False)
428420
class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
@@ -441,30 +433,16 @@ def __init__(
441433
super().__init__(
442434
allows_deferred_tools=allows_deferred_tools,
443435
toolset=toolset,
436+
object_def=processor.object_def,
444437
text_processor=processor,
445438
allows_image=allows_image,
446439
)
447440
self.processor = processor
448441

449-
def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]:
450-
if mode == 'native':
451-
return NativeOutputSchema(
452-
processor=self.processor,
453-
allows_deferred_tools=self.allows_deferred_tools,
454-
allows_image=self.allows_image,
455-
)
456-
elif mode == 'prompted':
457-
return PromptedOutputSchema(
458-
processor=self.processor,
459-
allows_deferred_tools=self.allows_deferred_tools,
460-
allows_image=self.allows_image,
461-
)
462-
elif mode == 'tool':
463-
return ToolOutputSchema(
464-
toolset=self.toolset, allows_deferred_tools=self.allows_deferred_tools, allows_image=self.allows_image
465-
)
466-
else:
467-
assert_never(mode)
442+
@property
443+
def mode(self) -> OutputMode | None:
444+
# TODO (DouweM): Could this be a field?
445+
return None
468446

469447

470448
@dataclass(init=False)
@@ -483,7 +461,7 @@ def __init__(
483461
)
484462

485463
@property
486-
def mode(self) -> OutputMode:
464+
def mode(self) -> OutputMode | None:
487465
return 'text'
488466

489467
def raise_if_unsupported(self, profile: ModelProfile) -> None:
@@ -496,7 +474,7 @@ def __init__(self, *, allows_deferred_tools: bool):
496474
super().__init__(allows_deferred_tools=allows_deferred_tools, allows_image=True)
497475

498476
@property
499-
def mode(self) -> OutputMode:
477+
def mode(self) -> OutputMode | None:
500478
return 'image'
501479

502480
def raise_if_unsupported(self, profile: ModelProfile) -> None:
@@ -513,18 +491,17 @@ def __init__(
513491
self, *, processor: BaseObjectOutputProcessor[OutputDataT], allows_deferred_tools: bool, allows_image: bool
514492
):
515493
super().__init__(
516-
text_processor=processor, allows_deferred_tools=allows_deferred_tools, allows_image=allows_image
494+
text_processor=processor,
495+
object_def=processor.object_def,
496+
allows_deferred_tools=allows_deferred_tools,
497+
allows_image=allows_image,
517498
)
518499
self.processor = processor
519500

520-
@property
521-
def object_def(self) -> OutputObjectDefinition:
522-
return self.processor.object_def
523-
524501

525502
class NativeOutputSchema(StructuredTextOutputSchema[OutputDataT]):
526503
@property
527-
def mode(self) -> OutputMode:
504+
def mode(self) -> OutputMode | None:
528505
return 'native'
529506

530507
def raise_if_unsupported(self, profile: ModelProfile) -> None:
@@ -553,7 +530,7 @@ def __init__(
553530
self.template = template
554531

555532
@property
556-
def mode(self) -> OutputMode:
533+
def mode(self) -> OutputMode | None:
557534
return 'prompted'
558535

559536
@classmethod
@@ -599,7 +576,7 @@ def __init__(
599576
)
600577

601578
@property
602-
def mode(self) -> OutputMode:
579+
def mode(self) -> OutputMode | None:
603580
return 'tool'
604581

605582
def raise_if_unsupported(self, profile: ModelProfile) -> None:

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -303,11 +303,7 @@ def __init__(
303303

304304
_utils.validate_empty_kwargs(_deprecated_kwargs)
305305

306-
default_output_mode = (
307-
self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None
308-
)
309-
310-
self._output_schema = _output.OutputSchema[OutputDataT].build(output_type, default_mode=default_output_mode)
306+
self._output_schema = _output.OutputSchema[OutputDataT].build(output_type)
311307
self._output_validators = []
312308

313309
self._instructions = self._normalize_instructions(instructions)
@@ -545,7 +541,7 @@ async def main():
545541
del model
546542

547543
deps = self._get_deps(deps)
548-
output_schema = self._prepare_output_schema(output_type, model_used.profile)
544+
output_schema = self._prepare_output_schema(output_type)
549545

550546
output_type_ = output_type or self.output_type
551547

@@ -556,7 +552,7 @@ async def main():
556552

557553
output_toolset = self._output_toolset
558554
if output_schema != self._output_schema or output_validators:
559-
output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset)
555+
output_toolset = output_schema.toolset
560556
if output_toolset:
561557
output_toolset.max_retries = self._max_result_retries
562558
output_toolset.output_validators = output_validators
@@ -588,11 +584,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
588584
*[await func.run(run_context) for func in instructions_functions],
589585
]
590586

591-
model_profile = model_used.profile
592-
if isinstance(output_schema, _output.PromptedOutputSchema):
593-
instructions = output_schema.instructions(model_profile.prompted_output_template)
594-
parts.append(instructions)
595-
596587
parts = [p for p in parts if p]
597588
if not parts:
598589
return None
@@ -1409,20 +1400,16 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
14091400
return toolsets
14101401

14111402
def _prepare_output_schema(
1412-
self, output_type: OutputSpec[RunOutputDataT] | None, model_profile: ModelProfile
1413-
) -> _output.OutputSchema[RunOutputDataT]:
1403+
self, output_type: OutputSpec[RunOutputDataT] | None
1404+
) -> _output.BaseOutputSchema[RunOutputDataT]:
14141405
if output_type is not None:
14151406
if self._output_validators:
14161407
raise exceptions.UserError('Cannot set a custom run `output_type` when the agent has output validators')
1417-
schema = _output.OutputSchema[RunOutputDataT].build(
1418-
output_type, default_mode=model_profile.default_structured_output_mode
1419-
)
1408+
schema = _output.OutputSchema[RunOutputDataT].build(output_type)
14201409
else:
1421-
schema = self._output_schema.with_default_mode(model_profile.default_structured_output_mode)
1422-
1423-
schema.raise_if_unsupported(model_profile)
1410+
schema = self._output_schema
14241411

1425-
return schema # pyright: ignore[reportReturnType]
1412+
return schema
14261413

14271414
async def __aenter__(self) -> Self:
14281415
"""Enter the agent context.

0 commit comments

Comments
 (0)