Skip to content

Commit 4d5e379

Browse files
committed
error handling fixes
1 parent b415832 commit 4d5e379

File tree

2 files changed

+40
-43
lines changed

2 files changed

+40
-43
lines changed

shell_engine.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -476,7 +476,6 @@ async def execute_pipeline(self, pipeline: list[dict]) -> str:
476476
return output
477477

478478
except Exception as e:
479-
import traceback
480-
481-
error_details = f"Pipeline execution failed: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
482-
return error_details
479+
# Re-raise so MCP layer sets isError=True in the response
480+
# This ensures clients properly display/handle the error
481+
raise RuntimeError(f"Pipeline execution failed: {str(e)}") from e

tests/test_shell_engine.py

Lines changed: 37 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -304,7 +304,7 @@ async def test_tool_stage_non_dict_json_upstream(self):
304304
array_input = '["item1", "item2", "item3"]'
305305
upstream = iter([array_input])
306306

307-
result = await engine.tool_stage("test_server", "test_tool", {}, upstream)
307+
await engine.tool_stage("test_server", "test_tool", {}, upstream)
308308

309309
# Should add array as 'input' field
310310
mock_caller.assert_called_once()
@@ -320,7 +320,7 @@ async def test_tool_stage_plain_text_upstream(self):
320320
text_input = "some plain text data"
321321
upstream = iter([text_input])
322322

323-
result = await engine.tool_stage("test_server", "test_tool", {}, upstream)
323+
await engine.tool_stage("test_server", "test_tool", {}, upstream)
324324

325325
# Should add text as 'input' field
326326
mock_caller.assert_called_once()
@@ -335,7 +335,7 @@ async def test_tool_stage_non_dict_json_does_not_override_existing_input(self):
335335
array_input = '["upstream_data"]'
336336
upstream = iter([array_input])
337337

338-
result = await engine.tool_stage(
338+
await engine.tool_stage(
339339
"test_server", "test_tool", {"input": "explicit_input"}, upstream
340340
)
341341

@@ -424,10 +424,10 @@ async def test_execute_pipeline_invalid_command(self):
424424

425425
pipeline = [{"type": "command", "command": "rm", "args": ["-rf", "/"]}]
426426

427-
result = await engine.execute_pipeline(pipeline)
428-
429-
assert "Pipeline execution failed" in result
430-
assert "not allowed" in result
427+
with pytest.raises(
428+
RuntimeError, match="Pipeline execution failed.*not allowed"
429+
):
430+
await engine.execute_pipeline(pipeline)
431431

432432
async def test_execute_pipeline_missing_command_field(self):
433433
"""Test pipeline with missing command field."""
@@ -438,10 +438,10 @@ async def test_execute_pipeline_missing_command_field(self):
438438
{"type": "command", "args": ["test"]} # Missing "command" field
439439
]
440440

441-
result = await engine.execute_pipeline(pipeline)
442-
443-
assert "Pipeline execution failed" in result
444-
assert "missing 'command' field" in result
441+
with pytest.raises(
442+
RuntimeError, match="Pipeline execution failed.*missing 'command' field"
443+
):
444+
await engine.execute_pipeline(pipeline)
445445

446446
async def test_execute_pipeline_missing_tool_name_field(self):
447447
"""Test pipeline with missing tool name field."""
@@ -451,10 +451,10 @@ async def test_execute_pipeline_missing_tool_name_field(self):
451451
# Missing "name" field
452452
pipeline = [{"type": "tool", "server": "test_server", "args": {}}]
453453

454-
result = await engine.execute_pipeline(pipeline)
455-
456-
assert "Pipeline execution failed" in result
457-
assert "missing 'name' field" in result
454+
with pytest.raises(
455+
RuntimeError, match="Pipeline execution failed.*missing 'name' field"
456+
):
457+
await engine.execute_pipeline(pipeline)
458458

459459
async def test_execute_pipeline_missing_tool_server_field(self):
460460
"""Test pipeline with missing tool server field."""
@@ -464,10 +464,10 @@ async def test_execute_pipeline_missing_tool_server_field(self):
464464
# Missing "server" field
465465
pipeline = [{"type": "tool", "name": "test_tool", "args": {}}]
466466

467-
result = await engine.execute_pipeline(pipeline)
468-
469-
assert "Pipeline execution failed" in result
470-
assert "missing 'server' field" in result
467+
with pytest.raises(
468+
RuntimeError, match="Pipeline execution failed.*missing 'server' field"
469+
):
470+
await engine.execute_pipeline(pipeline)
471471

472472
async def test_execute_pipeline_invalid_args_type(self):
473473
"""Test pipeline with invalid args type (not a list)."""
@@ -476,10 +476,10 @@ async def test_execute_pipeline_invalid_args_type(self):
476476

477477
pipeline = [{"type": "command", "command": "grep", "args": "not-a-list"}]
478478

479-
result = await engine.execute_pipeline(pipeline)
480-
481-
assert "Pipeline execution failed" in result
482-
assert "must be an array" in result
479+
with pytest.raises(
480+
RuntimeError, match="Pipeline execution failed.*must be an array"
481+
):
482+
await engine.execute_pipeline(pipeline)
483483

484484
async def test_execute_pipeline_unknown_stage_type(self):
485485
"""Test pipeline with unknown stage type."""
@@ -488,10 +488,10 @@ async def test_execute_pipeline_unknown_stage_type(self):
488488

489489
pipeline = [{"type": "unknown_type", "data": "test"}]
490490

491-
result = await engine.execute_pipeline(pipeline)
492-
493-
assert "Pipeline execution failed" in result
494-
assert "Unknown pipeline item type" in result
491+
with pytest.raises(
492+
RuntimeError, match="Pipeline execution failed.*Unknown pipeline item type"
493+
):
494+
await engine.execute_pipeline(pipeline)
495495

496496
async def test_execute_pipeline_empty_pipeline(self):
497497
"""Test pipeline with no stages returns empty string."""
@@ -548,10 +548,10 @@ async def failing_caller(server, tool, args):
548548

549549
pipeline = [{"type": "tool", "name": "test", "server": "test", "args": {}}]
550550

551-
result = await engine.execute_pipeline(pipeline)
552-
553-
assert "Pipeline execution failed" in result
554-
assert "Tool call failed" in result
551+
with pytest.raises(
552+
RuntimeError, match="Pipeline execution failed.*Tool call failed"
553+
):
554+
await engine.execute_pipeline(pipeline)
555555

556556
async def test_stage_error_includes_stage_number(self):
557557
"""Test that errors include the stage number."""
@@ -568,10 +568,8 @@ async def test_stage_error_includes_stage_number(self):
568568
{"type": "command", "command": "grep", "args": ["never reached"]},
569569
]
570570

571-
result = await engine.execute_pipeline(pipeline)
572-
573-
assert "Pipeline execution failed" in result
574-
assert "Stage 2" in result
571+
with pytest.raises(RuntimeError, match="Pipeline execution failed.*Stage 2"):
572+
await engine.execute_pipeline(pipeline)
575573

576574

577575
@pytest.mark.asyncio
@@ -640,11 +638,10 @@ async def test_execute_pipeline_command_timeout(self):
640638
]
641639

642640
start = time.time()
643-
result = await engine.execute_pipeline(pipeline)
641+
with pytest.raises(RuntimeError, match="Pipeline execution failed.*timed out"):
642+
await engine.execute_pipeline(pipeline)
644643
elapsed = time.time() - start
645644

646-
# Should fail with timeout error
647-
assert "timeout" in result.lower() or "timed out" in result.lower()
648645
assert elapsed < 0.5, f"Timeout took too long: {elapsed} seconds"
649646

650647
@pytest.mark.timeout(2)
@@ -666,7 +663,8 @@ async def test_execute_pipeline_for_each_with_timeout(self):
666663
]
667664

668665
start = time.time()
669-
await engine.execute_pipeline(pipeline)
666+
with pytest.raises(RuntimeError, match="Pipeline execution failed.*timed out"):
667+
await engine.execute_pipeline(pipeline)
670668
elapsed = time.time() - start
671669

672670
# Should timeout quickly, not wait 20 seconds (10s × 2 lines)

0 commit comments

Comments
 (0)