Skip to content

Commit b415832

Browse files
committed
improve code coverage
1 parent 6459b9c commit b415832

File tree

2 files changed

+334
-2
lines changed

2 files changed

+334
-2
lines changed

tests/test_mcp_client.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,244 @@ async def test_get_tool_details_connection_error(self, mocker):
509509
assert "Connection failed" in result["error"]
510510

511511

512+
@pytest.mark.asyncio
513+
class TestCallTool:
514+
"""Test the call_tool function"""
515+
516+
async def test_call_tool_success_streamable_http(self, mocker):
517+
"""Test successful tool call via streamable-http"""
518+
workload = {
519+
"name": "test-server",
520+
"status": "running",
521+
"transport_type": "streamable-http",
522+
"url": "http://localhost:8080/mcp",
523+
}
524+
525+
mocker.patch("mcp_client.get_workloads", return_value=[workload])
526+
mocker.patch(
527+
"toolhive_client.discover_toolhive", return_value=("localhost", 8080)
528+
)
529+
530+
# Mock the MCP client
531+
mock_result = MagicMock()
532+
mock_result.content = [MagicMock(text="tool result")]
533+
534+
mock_session = MagicMock()
535+
mock_session.initialize = AsyncMock()
536+
mock_session.call_tool = AsyncMock(return_value=mock_result)
537+
538+
mock_client_session = MagicMock()
539+
mock_client_session.__aenter__ = AsyncMock(return_value=mock_session)
540+
mock_client_session.__aexit__ = AsyncMock()
541+
542+
mock_http = MagicMock()
543+
mock_http.__aenter__ = AsyncMock(return_value=("read", "write", lambda: None))
544+
mock_http.__aexit__ = AsyncMock()
545+
546+
mocker.patch("mcp_client.streamablehttp_client", return_value=mock_http)
547+
mocker.patch("mcp_client.ClientSession", return_value=mock_client_session)
548+
549+
result = await mcp_client.call_tool(
550+
"test-server", "test_tool", {"param": "value"}
551+
)
552+
553+
assert result == mock_result
554+
mock_session.call_tool.assert_called_once_with(
555+
"test_tool", arguments={"param": "value"}
556+
)
557+
558+
async def test_call_tool_success_sse(self, mocker):
559+
"""Test successful tool call via SSE"""
560+
workload = {
561+
"name": "test-server",
562+
"status": "running",
563+
"proxy_mode": "sse",
564+
"url": "http://localhost:8080/sse",
565+
}
566+
567+
mocker.patch("mcp_client.get_workloads", return_value=[workload])
568+
mocker.patch(
569+
"toolhive_client.discover_toolhive", return_value=("localhost", 8080)
570+
)
571+
572+
mock_result = MagicMock()
573+
574+
mock_session = MagicMock()
575+
mock_session.initialize = AsyncMock()
576+
mock_session.call_tool = AsyncMock(return_value=mock_result)
577+
578+
mock_client_session = MagicMock()
579+
mock_client_session.__aenter__ = AsyncMock(return_value=mock_session)
580+
mock_client_session.__aexit__ = AsyncMock()
581+
582+
mock_sse = MagicMock()
583+
mock_sse.__aenter__ = AsyncMock(return_value=("read", "write"))
584+
mock_sse.__aexit__ = AsyncMock()
585+
586+
mocker.patch("mcp_client.sse_client", return_value=mock_sse)
587+
mocker.patch("mcp_client.ClientSession", return_value=mock_client_session)
588+
589+
result = await mcp_client.call_tool(
590+
"test-server", "test_tool", {"param": "value"}
591+
)
592+
593+
assert result == mock_result
594+
595+
async def test_call_tool_workload_not_found(self, mocker):
596+
"""Test call_tool when workload doesn't exist"""
597+
mocker.patch("mcp_client.get_workloads", return_value=[])
598+
mocker.patch(
599+
"toolhive_client.discover_toolhive", return_value=("localhost", 8080)
600+
)
601+
602+
with pytest.raises(ValueError, match="not found"):
603+
await mcp_client.call_tool("nonexistent", "test_tool", {})
604+
605+
async def test_call_tool_workload_not_running(self, mocker):
606+
"""Test call_tool when workload is not running"""
607+
workload = {
608+
"name": "test-server",
609+
"status": "stopped",
610+
"transport_type": "streamable-http",
611+
"url": "http://localhost:8080/mcp",
612+
}
613+
614+
mocker.patch("mcp_client.get_workloads", return_value=[workload])
615+
mocker.patch(
616+
"toolhive_client.discover_toolhive", return_value=("localhost", 8080)
617+
)
618+
619+
with pytest.raises(RuntimeError, match="not running"):
620+
await mcp_client.call_tool("test-server", "test_tool", {})
621+
622+
async def test_call_tool_no_url(self, mocker):
623+
"""Test call_tool when workload has no URL"""
624+
workload = {
625+
"name": "test-server",
626+
"status": "running",
627+
"transport_type": "streamable-http",
628+
"url": "",
629+
}
630+
631+
mocker.patch("mcp_client.get_workloads", return_value=[workload])
632+
mocker.patch(
633+
"toolhive_client.discover_toolhive", return_value=("localhost", 8080)
634+
)
635+
636+
with pytest.raises(ValueError, match="No URL"):
637+
await mcp_client.call_tool("test-server", "test_tool", {})
638+
639+
async def test_call_tool_unsupported_transport(self, mocker):
640+
"""Test call_tool with unsupported transport"""
641+
workload = {
642+
"name": "test-server",
643+
"status": "running",
644+
"transport_type": "stdio",
645+
"url": "http://localhost:8080/mcp",
646+
}
647+
648+
mocker.patch("mcp_client.get_workloads", return_value=[workload])
649+
mocker.patch(
650+
"toolhive_client.discover_toolhive", return_value=("localhost", 8080)
651+
)
652+
653+
with pytest.raises(ValueError, match="not supported"):
654+
await mcp_client.call_tool("test-server", "test_tool", {})
655+
656+
async def test_call_tool_discovery_fallback(self, mocker):
657+
"""Test call_tool falls back to defaults when discovery fails"""
658+
workload = {
659+
"name": "test-server",
660+
"status": "running",
661+
"transport_type": "streamable-http",
662+
"url": "http://localhost:8080/mcp",
663+
}
664+
665+
mock_get_workloads = mocker.patch(
666+
"mcp_client.get_workloads", return_value=[workload]
667+
)
668+
mocker.patch(
669+
"toolhive_client.discover_toolhive",
670+
side_effect=Exception("Discovery failed"),
671+
)
672+
673+
mock_result = MagicMock()
674+
mock_session = MagicMock()
675+
mock_session.initialize = AsyncMock()
676+
mock_session.call_tool = AsyncMock(return_value=mock_result)
677+
678+
mock_client_session = MagicMock()
679+
mock_client_session.__aenter__ = AsyncMock(return_value=mock_session)
680+
mock_client_session.__aexit__ = AsyncMock()
681+
682+
mock_http = MagicMock()
683+
mock_http.__aenter__ = AsyncMock(return_value=("read", "write", lambda: None))
684+
mock_http.__aexit__ = AsyncMock()
685+
686+
mocker.patch("mcp_client.streamablehttp_client", return_value=mock_http)
687+
mocker.patch("mcp_client.ClientSession", return_value=mock_client_session)
688+
689+
# Should not raise, should fall back to defaults
690+
result = await mcp_client.call_tool("test-server", "test_tool", {})
691+
692+
assert result == mock_result
693+
# Should have been called with default host/port
694+
mock_get_workloads.assert_called_once_with("127.0.0.1", 8080)
695+
696+
697+
@pytest.mark.asyncio
698+
class TestGetWorkloadsUrlRewriting:
699+
"""Test localhost URL rewriting for container networking"""
700+
701+
async def test_rewrites_localhost_urls(self, mocker):
702+
"""Test that localhost URLs are rewritten to use the actual host"""
703+
mock_response = MagicMock()
704+
mock_response.json.return_value = {
705+
"workloads": [
706+
{
707+
"name": "workload1",
708+
"url": "http://localhost:9000/mcp",
709+
},
710+
{
711+
"name": "workload2",
712+
"url": "http://127.0.0.1:9001/sse",
713+
},
714+
]
715+
}
716+
717+
mock_client = MagicMock()
718+
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
719+
mocker.patch("httpx.AsyncClient", return_value=mock_client)
720+
721+
# Call with a different host (simulating container environment)
722+
result = await mcp_client.get_workloads(host="192.168.1.100", port=8080)
723+
724+
# URLs should be rewritten
725+
assert result[0]["url"] == "http://192.168.1.100:9000/mcp"
726+
assert result[1]["url"] == "http://192.168.1.100:9001/sse"
727+
728+
async def test_preserves_non_localhost_urls(self, mocker):
729+
"""Test that non-localhost URLs are not rewritten"""
730+
mock_response = MagicMock()
731+
mock_response.json.return_value = {
732+
"workloads": [
733+
{
734+
"name": "workload1",
735+
"url": "http://some-service:9000/mcp",
736+
},
737+
]
738+
}
739+
740+
mock_client = MagicMock()
741+
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_response)
742+
mocker.patch("httpx.AsyncClient", return_value=mock_client)
743+
744+
result = await mcp_client.get_workloads(host="192.168.1.100", port=8080)
745+
746+
# URL should not be rewritten
747+
assert result[0]["url"] == "http://some-service:9000/mcp"
748+
749+
512750
@pytest.mark.asyncio
513751
class TestSelfFiltering:
514752
"""Test that mcp-shell filters itself out from tool listings"""

tests/test_shell_engine.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,21 @@ async def test_shell_stage_for_each_mode(self):
145145
assert "apricot" in output
146146
assert "banana" not in output
147147

148+
async def test_shell_stage_for_each_without_trailing_newline(self):
149+
"""Test shell command with for_each when input lacks trailing newline."""
150+
mock_caller = AsyncMock()
151+
engine = ShellEngine(tool_caller=mock_caller)
152+
153+
# Input without trailing newline - last line should still be processed
154+
upstream = iter(["apple\nbanana\napricot"])
155+
result = list(engine.shell_stage("grep", ["^a"], upstream, for_each=True))
156+
157+
output = "".join(result)
158+
# All lines starting with 'a' should be processed, including 'apricot'
159+
assert "apple" in output
160+
assert "apricot" in output
161+
assert "banana" not in output
162+
148163
async def test_shell_stage_empty_input(self):
149164
"""Test shell command with empty input."""
150165
mock_caller = AsyncMock()
@@ -263,6 +278,72 @@ async def test_tool_stage_result_with_no_content_attribute(self):
263278

264279
assert result == "plain string result"
265280

281+
async def test_tool_stage_for_each_without_trailing_newline(self):
282+
"""Test tool call with for_each when input doesn't have trailing newline."""
283+
mock_caller = AsyncMock(return_value=MockToolResult("result"))
284+
engine = ShellEngine(tool_caller=mock_caller)
285+
286+
# Input without trailing newline - should still process the last line
287+
jsonl_input = '{"url": "http://example.com/1"}\n{"url": "http://example.com/2"}'
288+
upstream = iter([jsonl_input])
289+
290+
result = await engine.tool_stage(
291+
"test_server", "fetch", {}, upstream, for_each=True
292+
)
293+
294+
# Should be called twice (once per line, including line without newline)
295+
assert mock_caller.call_count == 2
296+
assert "result" in result
297+
298+
async def test_tool_stage_non_dict_json_upstream(self):
299+
"""Test tool_stage with non-dict JSON upstream (array) in non-for_each mode."""
300+
mock_caller = AsyncMock(return_value=MockToolResult("result"))
301+
engine = ShellEngine(tool_caller=mock_caller)
302+
303+
# JSON array as upstream
304+
array_input = '["item1", "item2", "item3"]'
305+
upstream = iter([array_input])
306+
307+
result = await engine.tool_stage("test_server", "test_tool", {}, upstream)
308+
309+
# Should add array as 'input' field
310+
mock_caller.assert_called_once()
311+
call_args = mock_caller.call_args[0][2]
312+
assert call_args["input"] == ["item1", "item2", "item3"]
313+
314+
async def test_tool_stage_plain_text_upstream(self):
315+
"""Test tool_stage with plain text (non-JSON) upstream."""
316+
mock_caller = AsyncMock(return_value=MockToolResult("result"))
317+
engine = ShellEngine(tool_caller=mock_caller)
318+
319+
# Plain text that isn't valid JSON
320+
text_input = "some plain text data"
321+
upstream = iter([text_input])
322+
323+
result = await engine.tool_stage("test_server", "test_tool", {}, upstream)
324+
325+
# Should add text as 'input' field
326+
mock_caller.assert_called_once()
327+
call_args = mock_caller.call_args[0][2]
328+
assert call_args["input"] == "some plain text data"
329+
330+
async def test_tool_stage_non_dict_json_does_not_override_existing_input(self):
331+
"""Test that non-dict JSON doesn't override explicit 'input' arg."""
332+
mock_caller = AsyncMock(return_value=MockToolResult("result"))
333+
engine = ShellEngine(tool_caller=mock_caller)
334+
335+
array_input = '["upstream_data"]'
336+
upstream = iter([array_input])
337+
338+
result = await engine.tool_stage(
339+
"test_server", "test_tool", {"input": "explicit_input"}, upstream
340+
)
341+
342+
# Explicit input should be preserved
343+
mock_caller.assert_called_once()
344+
call_args = mock_caller.call_args[0][2]
345+
assert call_args["input"] == "explicit_input"
346+
266347

267348
@pytest.mark.asyncio
268349
class TestExecutePipeline:
@@ -362,8 +443,8 @@ async def test_execute_pipeline_missing_command_field(self):
362443
assert "Pipeline execution failed" in result
363444
assert "missing 'command' field" in result
364445

365-
async def test_execute_pipeline_missing_tool_fields(self):
366-
"""Test pipeline with missing tool fields."""
446+
async def test_execute_pipeline_missing_tool_name_field(self):
447+
"""Test pipeline with missing tool name field."""
367448
mock_caller = AsyncMock()
368449
engine = ShellEngine(tool_caller=mock_caller)
369450

@@ -375,6 +456,19 @@ async def test_execute_pipeline_missing_tool_fields(self):
375456
assert "Pipeline execution failed" in result
376457
assert "missing 'name' field" in result
377458

459+
async def test_execute_pipeline_missing_tool_server_field(self):
460+
"""Test pipeline with missing tool server field."""
461+
mock_caller = AsyncMock()
462+
engine = ShellEngine(tool_caller=mock_caller)
463+
464+
# Missing "server" field
465+
pipeline = [{"type": "tool", "name": "test_tool", "args": {}}]
466+
467+
result = await engine.execute_pipeline(pipeline)
468+
469+
assert "Pipeline execution failed" in result
470+
assert "missing 'server' field" in result
471+
378472
async def test_execute_pipeline_invalid_args_type(self):
379473
"""Test pipeline with invalid args type (not a list)."""
380474
mock_caller = AsyncMock()

0 commit comments

Comments
 (0)