Skip to content

Commit 441d6a0

Browse files
committed
use snapshots where it makes sense
1 parent 91f9533 commit 441d6a0

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

tests/test_cli.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,20 +232,19 @@ def append_to_clipboard(text: str) -> None:
232232
mocker.patch('pyperclip.copy', append_to_clipboard)
233233
assert handle_slash_command('/cp', [], False, Console(file=io), 'default') == (None, False)
234234
assert io.getvalue() == snapshot('No output available to copy.\n')
235-
assert len(mock_clipboard) == 0
235+
assert mock_clipboard == snapshot([])
236236

237237
messages: list[ModelMessage] = [ModelResponse(parts=[TextPart(''), ToolCallPart('foo', '{}')])]
238238
io = StringIO()
239239
assert handle_slash_command('/cp', messages, True, Console(file=io), 'default') == (None, True)
240240
assert io.getvalue() == snapshot('No text content to copy.\n')
241-
assert len(mock_clipboard) == 0
241+
assert mock_clipboard == snapshot([])
242242

243243
messages: list[ModelMessage] = [ModelResponse(parts=[TextPart('hello'), ToolCallPart('foo', '{}')])]
244244
io = StringIO()
245245
assert handle_slash_command('/cp', messages, True, Console(file=io), 'default') == (None, True)
246246
assert io.getvalue() == snapshot('Copied last output to clipboard.\n')
247-
assert len(mock_clipboard) == 1
248-
assert mock_clipboard[0] == snapshot('hello')
247+
assert mock_clipboard == snapshot(['hello'])
249248

250249

251250
def test_handle_slash_command_exit():

tests/test_ui_web.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from starlette.applications import Starlette
1818
from starlette.testclient import TestClient
1919

20-
from pydantic_ai.builtin_tools import WebSearchTool
20+
from pydantic_ai.builtin_tools import MCPServerTool, WebSearchTool
2121
from pydantic_ai.ui.web import create_web_app
2222

2323

@@ -287,7 +287,6 @@ def test_expand_env_vars_passthrough():
287287

288288
def test_load_mcp_server_tools_basic(tmp_path: Path):
289289
"""Test loading MCP server tools from a config file."""
290-
from pydantic_ai.builtin_tools import MCPServerTool
291290
from pydantic_ai.ui.web._mcp import load_mcp_server_tools
292291

293292
config = {
@@ -301,10 +300,7 @@ def test_load_mcp_server_tools_basic(tmp_path: Path):
301300
config_file.write_text(json.dumps(config), encoding='utf-8')
302301

303302
tools = load_mcp_server_tools(str(config_file))
304-
assert len(tools) == 1
305-
assert isinstance(tools[0], MCPServerTool)
306-
assert tools[0].id == 'test-server'
307-
assert tools[0].url == 'https://example.com/mcp'
303+
assert tools == snapshot([MCPServerTool(id='test-server', url='https://example.com/mcp')])
308304

309305

310306
def test_load_mcp_server_tools_with_all_fields(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
@@ -328,13 +324,18 @@ def test_load_mcp_server_tools_with_all_fields(tmp_path: Path, monkeypatch: pyte
328324
config_file.write_text(json.dumps(config), encoding='utf-8')
329325

330326
tools = load_mcp_server_tools(str(config_file))
331-
assert len(tools) == 1
332-
assert tools[0].id == 'full-server'
333-
assert tools[0].url == 'https://example.com/mcp'
334-
assert tools[0].authorization_token == 'my-secret-token'
335-
assert tools[0].description == 'A test MCP server'
336-
assert tools[0].allowed_tools == ['tool1', 'tool2']
337-
assert tools[0].headers == {'X-Custom': 'header-value'}
327+
assert tools == snapshot(
328+
[
329+
MCPServerTool(
330+
id='full-server',
331+
url='https://example.com/mcp',
332+
authorization_token='my-secret-token',
333+
description='A test MCP server',
334+
allowed_tools=['tool1', 'tool2'],
335+
headers={'X-Custom': 'header-value'},
336+
)
337+
]
338+
)
338339

339340

340341
def test_load_mcp_server_tools_file_not_found():
@@ -359,9 +360,12 @@ def test_load_mcp_server_tools_multiple_servers(tmp_path: Path):
359360
config_file.write_text(json.dumps(config), encoding='utf-8')
360361

361362
tools = load_mcp_server_tools(str(config_file))
362-
assert len(tools) == 2
363-
ids = {t.id for t in tools}
364-
assert ids == {'server-a', 'server-b'}
363+
assert tools == snapshot(
364+
[
365+
MCPServerTool(id='server-a', url='https://a.example.com/mcp'),
366+
MCPServerTool(id='server-b', url='https://b.example.com/mcp'),
367+
]
368+
)
365369

366370

367371
def test_mcp_server_tool_label():

0 commit comments

Comments
 (0)