|
1 | 1 | """End-to-end integration tests with mocked ComfyUI backend. |
2 | 2 |
|
3 | | -These tests wire up the full server stack (config -> security -> tools -> client) |
| 3 | +These tests wire up tool registration directly (via register_*_tools return dicts) |
4 | 4 | and exercise tools through the same code paths used in production. |
5 | 5 | """ |
6 | 6 |
|
7 | 7 | import json |
8 | 8 |
|
9 | 9 | import httpx |
10 | 10 | import respx |
| 11 | +from mcp.server.fastmcp import FastMCP |
11 | 12 |
|
12 | | -from comfyui_mcp.config import ComfyUISettings, SecuritySettings, Settings |
13 | | -from comfyui_mcp.security.inspector import WorkflowBlockedError |
14 | | -from comfyui_mcp.server import _build_server |
| 13 | +from comfyui_mcp.audit import AuditLogger |
| 14 | +from comfyui_mcp.client import ComfyUIClient |
| 15 | +from comfyui_mcp.security.inspector import WorkflowBlockedError, WorkflowInspector |
| 16 | +from comfyui_mcp.security.rate_limit import RateLimiter |
| 17 | +from comfyui_mcp.security.sanitizer import PathSanitizer |
| 18 | +from comfyui_mcp.tools.discovery import register_discovery_tools |
| 19 | +from comfyui_mcp.tools.generation import register_generation_tools |
| 20 | +from comfyui_mcp.tools.jobs import register_job_tools |
15 | 21 |
|
16 | 22 |
|
17 | 23 | class TestImageGenerationFlow: |
18 | 24 | @respx.mock |
19 | | - async def test_generate_image_lists_models_then_generates(self): |
| 25 | + async def test_generate_image_lists_models_then_generates(self, tmp_path): |
20 | 26 | """Full flow: list models -> generate image -> check job.""" |
21 | 27 | respx.get("http://mock-comfyui:8188/models/checkpoints").mock( |
22 | 28 | return_value=httpx.Response(200, json=["sd_v15.safetensors"]) |
@@ -47,58 +53,66 @@ async def test_generate_image_lists_models_then_generates(self): |
47 | 53 | ) |
48 | 54 | ) |
49 | 55 |
|
50 | | - settings = Settings(comfyui=ComfyUISettings(url="http://mock-comfyui:8188")) |
51 | | - server, _ = _build_server(settings) |
| 56 | + client = ComfyUIClient(base_url="http://mock-comfyui:8188") |
| 57 | + audit = AuditLogger(audit_file=tmp_path / "audit.log") |
| 58 | + limiter = RateLimiter(max_per_minute=60) |
| 59 | + inspector = WorkflowInspector(mode="audit", dangerous_nodes=[], allowed_nodes=[]) |
| 60 | + sanitizer = PathSanitizer(allowed_extensions=[".png", ".jpg", ".jpeg", ".webp"]) |
| 61 | + mcp = FastMCP("test") |
| 62 | + |
| 63 | + discovery_tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer) |
| 64 | + gen_tools = register_generation_tools(mcp, client, audit, limiter, inspector) |
| 65 | + job_tools = register_job_tools(mcp, client, audit, limiter) |
52 | 66 |
|
53 | 67 | # Step 1: Discover available models |
54 | | - tools = server._tool_manager._tools |
55 | | - list_models_fn = tools["list_models"].fn |
56 | | - models = await list_models_fn(folder="checkpoints") |
| 68 | + models = await discovery_tools["list_models"](folder="checkpoints") |
57 | 69 | assert "sd_v15.safetensors" in models |
58 | 70 |
|
59 | 71 | # Step 2: Generate an image |
60 | | - generate_fn = tools["generate_image"].fn |
61 | | - result = await generate_fn(prompt="a sunset over mountains") |
| 72 | + result = await gen_tools["generate_image"](prompt="a sunset over mountains") |
62 | 73 | assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in result |
63 | 74 |
|
64 | 75 | # Step 3: Check the job |
65 | | - get_job_fn = tools["get_job"].fn |
66 | | - job = await get_job_fn(prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") |
| 76 | + job = await job_tools["get_job"](prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") |
67 | 77 | assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in job |
68 | 78 |
|
69 | 79 | @respx.mock |
70 | | - async def test_run_workflow_with_dangerous_node_in_audit_mode(self): |
| 80 | + async def test_run_workflow_with_dangerous_node_in_audit_mode(self, tmp_path): |
71 | 81 | """Audit mode logs dangerous nodes but still submits the workflow.""" |
72 | 82 | respx.post("http://mock-comfyui:8188/prompt").mock( |
73 | 83 | return_value=httpx.Response( |
74 | 84 | 200, json={"prompt_id": "11111111-2222-3333-4444-555555555555"} |
75 | 85 | ) |
76 | 86 | ) |
77 | 87 |
|
78 | | - settings = Settings(comfyui=ComfyUISettings(url="http://mock-comfyui:8188")) |
79 | | - server, _ = _build_server(settings) |
| 88 | + client = ComfyUIClient(base_url="http://mock-comfyui:8188") |
| 89 | + audit = AuditLogger(audit_file=tmp_path / "audit.log") |
| 90 | + limiter = RateLimiter(max_per_minute=60) |
| 91 | + inspector = WorkflowInspector(mode="audit", dangerous_nodes=["Terminal"], allowed_nodes=[]) |
| 92 | + mcp = FastMCP("test") |
80 | 93 |
|
81 | | - run_workflow_fn = server._tool_manager._tools["run_workflow"].fn |
| 94 | + tools = register_generation_tools(mcp, client, audit, limiter, inspector) |
82 | 95 | workflow = json.dumps({"1": {"class_type": "Terminal", "inputs": {}}}) |
83 | | - result = await run_workflow_fn(workflow=workflow) |
| 96 | + result = await tools["run_workflow"](workflow=workflow) |
84 | 97 | assert "11111111-2222-3333-4444-555555555555" in result |
85 | 98 | assert "Terminal" in result |
86 | 99 |
|
87 | | - async def test_run_workflow_blocked_in_enforce_mode(self): |
| 100 | + async def test_run_workflow_blocked_in_enforce_mode(self, tmp_path): |
88 | 101 | """Enforce mode blocks workflows with unapproved nodes.""" |
89 | | - settings = Settings( |
90 | | - comfyui=ComfyUISettings(url="http://mock-comfyui:8188"), |
91 | | - security=SecuritySettings( |
92 | | - mode="enforce", |
93 | | - allowed_nodes=["KSampler", "CLIPTextEncode"], |
94 | | - ), |
| 102 | + client = ComfyUIClient(base_url="http://mock-comfyui:8188") |
| 103 | + audit = AuditLogger(audit_file=tmp_path / "audit.log") |
| 104 | + limiter = RateLimiter(max_per_minute=60) |
| 105 | + inspector = WorkflowInspector( |
| 106 | + mode="enforce", |
| 107 | + dangerous_nodes=[], |
| 108 | + allowed_nodes=["KSampler", "CLIPTextEncode"], |
95 | 109 | ) |
96 | | - server, _ = _build_server(settings) |
| 110 | + mcp = FastMCP("test") |
97 | 111 |
|
98 | | - run_workflow_fn = server._tool_manager._tools["run_workflow"].fn |
| 112 | + tools = register_generation_tools(mcp, client, audit, limiter, inspector) |
99 | 113 | workflow = json.dumps({"1": {"class_type": "MaliciousNode", "inputs": {}}}) |
100 | 114 | try: |
101 | | - await run_workflow_fn(workflow=workflow) |
| 115 | + await tools["run_workflow"](workflow=workflow) |
102 | 116 | raise AssertionError("Should have raised WorkflowBlockedError") |
103 | 117 | except WorkflowBlockedError as e: |
104 | 118 | assert "MaliciousNode" in str(e) |
0 commit comments