Skip to content

Commit e6bc181

Browse files
committed
Small refactor to address comments
1 parent 5283df9 commit e6bc181

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -962,29 +962,30 @@ def output_json_schema(self, output_type: OutputSpec[OutputDataT | RunOutputData
962962
if output_type is None:
963963
output_type = self.output_type
964964

965-
# call this first to force output_type to an iterable
965+
# forces output_type to an iterable
966966
output_type = list(_flatten_output_spec(output_type))
967967

968-
# flatten special outputs
969-
for i, _ in enumerate(output_type):
970-
if isinstance(_, _output.NativeOutput):
971-
output_type[i] = _flatten_output_spec(_.outputs)
972-
if isinstance(_, _output.PromptedOutput):
973-
output_type[i] = _flatten_output_spec(_.outputs)
974-
if isinstance(_, _output.ToolOutput):
975-
output_type[i] = _flatten_output_spec(_.output)
976-
977-
# final flattening
978-
output_type = _flatten_output_spec(output_type)
968+
flat_output_type: list[OutputSpec[OutputDataT | RunOutputDataT] | type[str]] = []
969+
for output_spec in output_type:
970+
if isinstance(output_spec, _output.NativeOutput):
971+
flat_output_type += _flatten_output_spec(output_spec.outputs)
972+
elif isinstance(output_spec, _output.PromptedOutput):
973+
flat_output_type += _flatten_output_spec(output_spec.outputs)
974+
elif isinstance(output_spec, _output.TextOutput):
975+
flat_output_type.append(str)
976+
elif isinstance(output_spec, _output.ToolOutput):
977+
flat_output_type += _flatten_output_spec(output_spec.output)
978+
else:
979+
flat_output_type.append(output_spec)
979980

980981
json_schemas: list[JsonSchema] = []
981-
for _ in output_type:
982-
if inspect.isfunction(_) or inspect.ismethod(_):
983-
json_schema = TypeAdapter(inspect.signature(_).return_annotation).json_schema(mode='serialization')
984-
elif isinstance(_, _output.TextOutput):
985-
json_schema = TypeAdapter(str).json_schema(mode='serialization')
982+
for output_spec in flat_output_type:
983+
if inspect.isfunction(output_spec) or inspect.ismethod(output_spec):
984+
json_schema = TypeAdapter(inspect.signature(output_spec).return_annotation).json_schema(
985+
mode='serialization'
986+
)
986987
else:
987-
json_schema = TypeAdapter(_).json_schema(mode='serialization')
988+
json_schema = TypeAdapter(output_spec).json_schema(mode='serialization')
988989

989990
if json_schema not in json_schemas:
990991
json_schemas.append(json_schema)

0 commit comments

Comments
 (0)