Skip to content

Commit 09184e2

Browse files
committed
More test stuff
1 parent ba9433f commit 09184e2

File tree

5 files changed

+51
-39
lines changed

5 files changed

+51
-39
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ def dump(self) -> JsonSchema:
420420
processor.object_def.description = tool_def.description
421421
processors.append(processor)
422422
return UnionOutputProcessor(processors).object_def.json_schema
423+
423424
return self.processor.object_def.json_schema
424425

425426

@@ -550,16 +551,16 @@ def mode(self) -> OutputMode:
550551
return 'tool'
551552

552553
def dump(self) -> JsonSchema:
553-
if self.toolset:
554-
processors: list[ObjectOutputProcessor[OutputDataT]] = []
555-
for tool_def in self.toolset._tool_defs: # pyright: ignore [reportPrivateUsage]
556-
processor = copy.copy(self.toolset.processors[tool_def.name])
557-
processor.object_def.name = tool_def.name
558-
processor.object_def.description = tool_def.description
559-
processors.append(processor)
560-
return UnionOutputProcessor(processors).object_def.json_schema
561-
else:
562-
raise RuntimeError('ToolOutputSchema has no toolset.')
554+
if self.toolset is None:
555+
# need to check expected behavior
556+
raise NotImplementedError()
557+
processors: list[ObjectOutputProcessor[OutputDataT]] = []
558+
for tool_def in self.toolset._tool_defs: # pyright: ignore [reportPrivateUsage]
559+
processor = copy.copy(self.toolset.processors[tool_def.name])
560+
processor.object_def.name = tool_def.name
561+
processor.object_def.description = tool_def.description
562+
processors.append(processor)
563+
return UnionOutputProcessor(processors).object_def.json_schema
563564

564565

565566
class BaseOutputProcessor(ABC, Generic[OutputDataT]):

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,7 @@ def decorator(
948948
self._system_prompt_functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func, dynamic=dynamic))
949949
return func
950950

951-
def output_json_schema(self, output_type: OutputSpec[OutputDataT] | None = None) -> JsonSchema:
951+
def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
952952
"""The output JSON schema."""
953953
output_schema = self._prepare_output_schema(output_type)
954954
return output_schema.dump()

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def toolsets(self) -> Sequence[AbstractToolset[AgentDepsT]]:
124124
raise NotImplementedError
125125

126126
@abstractmethod
127-
def output_json_schema(self, output_type: OutputSpec[OutputDataT] | None = None) -> JsonSchema:
127+
def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
128128
"""The output JSON schema."""
129129
raise NotImplementedError
130130

pydantic_ai_slim/pydantic_ai/agent/wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ async def __aenter__(self) -> AbstractAgent[AgentDepsT, OutputDataT]:
6868
async def __aexit__(self, *args: Any) -> bool | None:
6969
return await self.wrapped.__aexit__(*args)
7070

71-
def output_json_schema(self, output_type: OutputSpec[OutputDataT] | None = None) -> JsonSchema:
71+
def output_json_schema(self, output_type: OutputSpec[RunOutputDataT] | None = None) -> JsonSchema:
7272
return self.wrapped.output_json_schema(output_type=output_type)
7373

7474
@overload

tests/test_agent.py

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5155,6 +5155,35 @@ def foo() -> str:
51555155
assert wrapper_agent.name == 'wrapped'
51565156
assert wrapper_agent.output_type == agent.output_type
51575157
assert wrapper_agent.event_stream_handler == agent.event_stream_handler
5158+
assert wrapper_agent.output_json_schema() == snapshot(
5159+
{
5160+
'type': 'object',
5161+
'properties': {
5162+
'result': {
5163+
'anyOf': [
5164+
{
5165+
'type': 'object',
5166+
'properties': {
5167+
'kind': {'type': 'string', 'const': 'final_result'},
5168+
'data': {
5169+
'properties': {'a': {'type': 'integer'}, 'b': {'type': 'string'}},
5170+
'required': ['a', 'b'],
5171+
'type': 'object',
5172+
},
5173+
},
5174+
'required': ['kind', 'data'],
5175+
'additionalProperties': False,
5176+
'title': 'final_result',
5177+
'description': 'The final response which ends this conversation',
5178+
}
5179+
]
5180+
}
5181+
},
5182+
'required': ['result'],
5183+
'additionalProperties': False,
5184+
}
5185+
)
5186+
assert wrapper_agent.output_json_schema(output_type=str) == snapshot({'type': 'string'})
51585187

51595188
bar_toolset = FunctionToolset()
51605189

@@ -6151,19 +6180,13 @@ def test_message_history_cannot_start_with_model_response():
61516180

61526181

61536182
async def test_text_output_json_schema():
6154-
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
6155-
return ModelResponse(parts=[TextPart('')])
6156-
6157-
agent = Agent(FunctionModel(llm))
6183+
agent = Agent('test')
61586184
assert agent.output_json_schema() == snapshot({'type': 'string'})
61596185

61606186

61616187
async def test_tool_output_json_schema():
6162-
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
6163-
return ModelResponse(parts=[TextPart('')])
6164-
61656188
agent = Agent(
6166-
FunctionModel(llm),
6189+
'test',
61676190
output_type=[ToolOutput(bool, name='alice', description='Dreaming...')],
61686191
)
61696192
assert agent.output_json_schema() == snapshot(
@@ -6196,7 +6219,7 @@ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
61966219
)
61976220

61986221
agent = Agent(
6199-
FunctionModel(llm),
6222+
'test',
62006223
output_type=[ToolOutput(bool, name='alice'), ToolOutput(bool, name='bob')],
62016224
)
62026225
assert agent.output_json_schema() == snapshot(
@@ -6245,11 +6268,8 @@ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
62456268

62466269

62476270
async def test_native_output_json_schema():
6248-
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
6249-
return ModelResponse(parts=[TextPart('')])
6250-
62516271
agent = Agent(
6252-
FunctionModel(llm),
6272+
'test',
62536273
output_type=NativeOutput([bool], name='native_output_name', description='native_output_description'),
62546274
)
62556275
assert agent.output_json_schema() == snapshot(
@@ -6258,11 +6278,8 @@ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
62586278

62596279

62606280
async def test_prompted_output_json_schema():
6261-
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
6262-
return ModelResponse(parts=[TextPart('')])
6263-
62646281
agent = Agent(
6265-
FunctionModel(llm),
6282+
'test',
62666283
output_type=PromptedOutput([bool], name='prompted_output_name', description='prompted_output_description'),
62676284
)
62686285
assert agent.output_json_schema() == snapshot(
@@ -6271,9 +6288,6 @@ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
62716288

62726289

62736290
async def test_custom_output_json_schema():
6274-
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
6275-
return ModelResponse(parts=[TextPart('')])
6276-
62776291
HumanDict = StructuredDict(
62786292
{
62796293
'type': 'object',
@@ -6283,7 +6297,7 @@ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
62836297
name='Human',
62846298
description='A human with a name and age',
62856299
)
6286-
agent = Agent(FunctionModel(llm), output_type=HumanDict)
6300+
agent = Agent('test', output_type=HumanDict)
62876301
assert agent.output_json_schema() == snapshot(
62886302
{
62896303
'type': 'object',
@@ -6315,12 +6329,9 @@ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
63156329

63166330

63176331
async def test_override_output_json_schema():
6318-
def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
6319-
return ModelResponse(parts=[TextPart('')])
6320-
6321-
agent = Agent(FunctionModel(llm))
6332+
agent = Agent('test')
63226333
assert agent.output_json_schema() == snapshot({'type': 'string'})
6323-
output_type = ([ToolOutput(bool, name='alice', description='Dreaming...')],)
6334+
output_type = [ToolOutput(bool, name='alice', description='Dreaming...')]
63246335
assert agent.output_json_schema(output_type=output_type) == snapshot(
63256336
{
63266337
'type': 'object',

0 commit comments

Comments
 (0)