Skip to content

Commit 6766617

Browse files
committed
Add tests
1 parent 2091142 commit 6766617

File tree

4 files changed

+246
-8
lines changed

4 files changed

+246
-8
lines changed

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,10 @@ async def _map_tool_result_part(
233233
return self._get_content(resource)
234234
elif isinstance(part, mcp_types.ResourceLink):
235235
resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri)
236-
return [self._get_content(resource) for resource in resource_result.contents]
237-
236+
if len(resource_result.contents) > 1:
237+
return [self._get_content(resource) for resource in resource_result.contents]
238+
else:
239+
return self._get_content(resource_result.contents[0])
238240
else:
239241
assert_never(part)
240242

tests/assets/product_name.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
PydanticAI

tests/mcp_server.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from mcp.types import (
99
BlobResourceContents,
1010
EmbeddedResource,
11+
ResourceLink,
1112
SamplingMessage,
1213
TextContent,
1314
TextResourceContents,
@@ -50,13 +51,22 @@ async def get_image_resource() -> EmbeddedResource:
5051
return EmbeddedResource(
5152
type='resource',
5253
resource=BlobResourceContents(
53-
uri='resource://kiwi.png', # type: ignore
54+
uri=AnyUrl('resource://kiwi.png'),
5455
blob=base64.b64encode(data).decode('utf-8'),
5556
mimeType='image/png',
5657
),
5758
)
5859

5960

61+
@mcp.tool()
62+
async def get_image_resource_1() -> ResourceLink:
63+
return ResourceLink(
64+
type='resource_link',
65+
uri=AnyUrl(Path(__file__).parent.joinpath('assets/kiwi.png').absolute().as_uri()),
66+
name='kiwi.png',
67+
)
68+
69+
6070
@mcp.tool()
6171
async def get_audio_resource() -> EmbeddedResource:
6272
data = Path(__file__).parent.joinpath('assets/marcelo.mp3').read_bytes()
@@ -70,17 +80,35 @@ async def get_audio_resource() -> EmbeddedResource:
7080
)
7181

7282

83+
@mcp.tool()
84+
async def get_audio_resource_1() -> ResourceLink:
85+
return ResourceLink(
86+
type='resource_link',
87+
uri=AnyUrl(Path(__file__).parent.joinpath('assets/marcelo.mp3').absolute().as_uri()),
88+
name='marcelo.mp3',
89+
)
90+
91+
7392
@mcp.tool()
7493
async def get_product_name() -> EmbeddedResource:
7594
return EmbeddedResource(
7695
type='resource',
7796
resource=TextResourceContents(
78-
uri='resource://product_name.txt', # type: ignore
97+
uri=AnyUrl('resource://product_name.txt'),
7998
text='PydanticAI',
8099
),
81100
)
82101

83102

103+
@mcp.tool()
104+
async def get_product_name_1() -> ResourceLink:
105+
return ResourceLink(
106+
type='resource_link',
107+
uri=AnyUrl(Path(__file__).parent.joinpath('assets/product_name.txt').absolute().as_uri()),
108+
name='product_name.txt',
109+
)
110+
111+
84112
@mcp.tool()
85113
async def get_image() -> Image:
86114
data = Path(__file__).parent.joinpath('assets/kiwi.png').read_bytes()

tests/test_mcp.py

Lines changed: 211 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ async def test_stdio_server():
5858
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
5959
async with server:
6060
tools = await server.list_tools()
61-
assert len(tools) == snapshot(13)
61+
assert len(tools) == snapshot(16)
6262
assert tools[0].name == 'celsius_to_fahrenheit'
6363
assert isinstance(tools[0].description, str)
6464
assert tools[0].description.startswith('Convert Celsius to Fahrenheit.')
@@ -87,7 +87,7 @@ async def test_stdio_server_with_cwd():
8787
server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir)
8888
async with server:
8989
tools = await server.list_tools()
90-
assert len(tools) == snapshot(13)
90+
assert len(tools) == snapshot(16)
9191

9292

9393
async def test_process_tool_call() -> None:
@@ -254,8 +254,8 @@ async def test_log_level_unset():
254254
assert server.log_level is None
255255
async with server:
256256
tools = await server.list_tools()
257-
assert len(tools) == snapshot(13)
258-
assert tools[10].name == 'get_log_level'
257+
assert len(tools) == snapshot(16)
258+
assert tools[13].name == 'get_log_level'
259259

260260
result = await server.call_tool('get_log_level', {})
261261
assert result == snapshot('unset')
@@ -421,6 +421,79 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A
421421
)
422422

423423

424+
@pytest.mark.vcr()
425+
async def test_tool_returning_text_resource_1(allow_model_requests: None, agent: Agent):
426+
async with agent.run_mcp_servers():
427+
result = await agent.run('Get me the product name')
428+
assert result.output == snapshot('The product name is "PydanticAI".')
429+
assert result.all_messages() == snapshot(
430+
[
431+
ModelRequest(
432+
parts=[
433+
UserPromptPart(
434+
content='Get me the product name',
435+
timestamp=IsDatetime(),
436+
)
437+
]
438+
),
439+
ModelResponse(
440+
parts=[
441+
ToolCallPart(
442+
tool_name='get_product_name_1',
443+
args='{}',
444+
tool_call_id='call_LaiWltzI39sdquflqeuF0EyE',
445+
)
446+
],
447+
usage=Usage(
448+
requests=1,
449+
request_tokens=200,
450+
response_tokens=12,
451+
total_tokens=212,
452+
details={
453+
'accepted_prediction_tokens': 0,
454+
'audio_tokens': 0,
455+
'reasoning_tokens': 0,
456+
'rejected_prediction_tokens': 0,
457+
'cached_tokens': 0,
458+
},
459+
),
460+
model_name='gpt-4o-2024-08-06',
461+
timestamp=IsDatetime(),
462+
vendor_id='chatcmpl-BRmhyweJVYonarb7s9ckIMSHf2vHo',
463+
),
464+
ModelRequest(
465+
parts=[
466+
ToolReturnPart(
467+
tool_name='get_product_name_1',
468+
content='PydanticAI',
469+
tool_call_id='call_LaiWltzI39sdquflqeuF0EyE',
470+
timestamp=IsDatetime(),
471+
)
472+
]
473+
),
474+
ModelResponse(
475+
parts=[TextPart(content='The product name is "PydanticAI".')],
476+
usage=Usage(
477+
requests=1,
478+
request_tokens=224,
479+
response_tokens=12,
480+
total_tokens=236,
481+
details={
482+
'accepted_prediction_tokens': 0,
483+
'audio_tokens': 0,
484+
'reasoning_tokens': 0,
485+
'rejected_prediction_tokens': 0,
486+
'cached_tokens': 0,
487+
},
488+
),
489+
model_name='gpt-4o-2024-08-06',
490+
timestamp=IsDatetime(),
491+
vendor_id='chatcmpl-BRmhzqXFObpYwSzREMpJvX9kbDikR',
492+
),
493+
]
494+
)
495+
496+
424497
@pytest.mark.vcr()
425498
async def test_tool_returning_image_resource(allow_model_requests: None, agent: Agent, image_content: BinaryContent):
426499
async with agent.run_mcp_servers():
@@ -501,6 +574,86 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent:
501574
)
502575

503576

577+
@pytest.mark.vcr()
578+
async def test_tool_returning_image_resource_1(allow_model_requests: None, agent: Agent, image_content: BinaryContent):
579+
async with agent.run_mcp_servers():
580+
result = await agent.run('Get me the image resource')
581+
assert result.output == snapshot(
582+
'This is an image of a sliced kiwi with a vibrant green interior and black seeds.'
583+
)
584+
assert result.all_messages() == snapshot(
585+
[
586+
ModelRequest(
587+
parts=[
588+
UserPromptPart(
589+
content='Get me the image resource',
590+
timestamp=IsDatetime(),
591+
)
592+
]
593+
),
594+
ModelResponse(
595+
parts=[
596+
ToolCallPart(
597+
tool_name='get_image_resource_1',
598+
args='{}',
599+
tool_call_id='call_nFsDHYDZigO0rOHqmChZ3pmt',
600+
)
601+
],
602+
usage=Usage(
603+
requests=1,
604+
request_tokens=191,
605+
response_tokens=12,
606+
total_tokens=203,
607+
details={
608+
'accepted_prediction_tokens': 0,
609+
'audio_tokens': 0,
610+
'reasoning_tokens': 0,
611+
'rejected_prediction_tokens': 0,
612+
'cached_tokens': 0,
613+
},
614+
),
615+
model_name='gpt-4o-2024-08-06',
616+
timestamp=IsDatetime(),
617+
vendor_id='chatcmpl-BRlo7KYJVXuNZ5lLLdYcKZDsX2CHb',
618+
),
619+
ModelRequest(
620+
parts=[
621+
ToolReturnPart(
622+
tool_name='get_image_resource_1',
623+
content='See file 1c8566',
624+
tool_call_id='call_nFsDHYDZigO0rOHqmChZ3pmt',
625+
timestamp=IsDatetime(),
626+
),
627+
UserPromptPart(content=['This is file 1c8566:', image_content], timestamp=IsDatetime()),
628+
]
629+
),
630+
ModelResponse(
631+
parts=[
632+
TextPart(
633+
content='This is an image of a sliced kiwi with a vibrant green interior and black seeds.'
634+
)
635+
],
636+
usage=Usage(
637+
requests=1,
638+
request_tokens=1332,
639+
response_tokens=19,
640+
total_tokens=1351,
641+
details={
642+
'accepted_prediction_tokens': 0,
643+
'audio_tokens': 0,
644+
'reasoning_tokens': 0,
645+
'rejected_prediction_tokens': 0,
646+
'cached_tokens': 0,
647+
},
648+
),
649+
model_name='gpt-4o-2024-08-06',
650+
timestamp=IsDatetime(),
651+
vendor_id='chatcmpl-BRloBGHh27w3fQKwxq4fX2cPuZJa9',
652+
),
653+
]
654+
)
655+
656+
504657
@pytest.mark.vcr()
505658
async def test_tool_returning_audio_resource(
506659
allow_model_requests: None, agent: Agent, audio_content: BinaryContent, gemini_api_key: str
@@ -555,6 +708,60 @@ async def test_tool_returning_audio_resource(
555708
)
556709

557710

711+
@pytest.mark.vcr()
712+
async def test_tool_returning_audio_resource_1(
713+
allow_model_requests: None, agent: Agent, audio_content: BinaryContent, gemini_api_key: str
714+
):
715+
model = GoogleModel('gemini-2.5-pro-preview-03-25', provider=GoogleProvider(api_key=gemini_api_key))
716+
async with agent.run_mcp_servers():
717+
result = await agent.run("What's the content of the audio resource?", model=model)
718+
assert result.output == snapshot('The audio resource contains a voice saying "Hello, my name is Marcelo."')
719+
assert result.all_messages() == snapshot(
720+
[
721+
ModelRequest(
722+
parts=[UserPromptPart(content="What's the content of the audio resource?", timestamp=IsDatetime())]
723+
),
724+
ModelResponse(
725+
parts=[ToolCallPart(tool_name='get_audio_resource_1', args={}, tool_call_id=IsStr())],
726+
usage=Usage(
727+
requests=1,
728+
request_tokens=383,
729+
response_tokens=12,
730+
total_tokens=520,
731+
details={'thoughts_tokens': 125, 'text_prompt_tokens': 383},
732+
),
733+
model_name='models/gemini-2.5-pro-preview-05-06',
734+
timestamp=IsDatetime(),
735+
vendor_details={'finish_reason': 'STOP'},
736+
),
737+
ModelRequest(
738+
parts=[
739+
ToolReturnPart(
740+
tool_name='get_audio_resource_1',
741+
content='See file 2d36ae',
742+
tool_call_id=IsStr(),
743+
timestamp=IsDatetime(),
744+
),
745+
UserPromptPart(content=['This is file 2d36ae:', audio_content], timestamp=IsDatetime()),
746+
]
747+
),
748+
ModelResponse(
749+
parts=[TextPart(content='The audio resource contains a voice saying "Hello, my name is Marcelo."')],
750+
usage=Usage(
751+
requests=1,
752+
request_tokens=575,
753+
response_tokens=15,
754+
total_tokens=590,
755+
details={'text_prompt_tokens': 431, 'audio_prompt_tokens': 144},
756+
),
757+
model_name='models/gemini-2.5-pro-preview-05-06',
758+
timestamp=IsDatetime(),
759+
vendor_details={'finish_reason': 'STOP'},
760+
),
761+
]
762+
)
763+
764+
558765
@pytest.mark.vcr()
559766
async def test_tool_returning_image(allow_model_requests: None, agent: Agent, image_content: BinaryContent):
560767
async with agent.run_mcp_servers():

0 commit comments

Comments
 (0)