diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e81727d..3595b53 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,10 +10,10 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Install uv - uses: astral-sh/setup-uv@v4 + uses: astral-sh/setup-uv@e4db8464a088ece1b920f60402e813ea4de65b8f # v4 - name: Set up Python run: uv python install 3.12 @@ -30,10 +30,10 @@ jobs: type-check: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Install uv - uses: astral-sh/setup-uv@v4 + uses: astral-sh/setup-uv@e4db8464a088ece1b920f60402e813ea4de65b8f # v4 - name: Set up Python run: uv python install 3.12 @@ -47,10 +47,10 @@ jobs: test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Install uv - uses: astral-sh/setup-uv@v4 + uses: astral-sh/setup-uv@e4db8464a088ece1b920f60402e813ea4de65b8f # v4 - name: Set up Python run: uv python install 3.12 diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index d844381..c83e02b 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -17,16 +17,16 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3 - name: Log in to Container Registry - uses: docker/login-action@v3 + uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3 with: registry: ghcr.io username: ${{ github.actor }} @@ -34,7 +34,7 @@ jobs: - name: Extract metadata id: meta - uses: docker/metadata-action@v5 + uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5 with: images: ghcr.io/${{ github.repository }} tags: | @@ -46,7 +46,7 @@ jobs: type=sha,prefix= - name: Build and push - uses: docker/build-push-action@v5 + uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5 with: context: . push: true diff --git a/Dockerfile b/Dockerfile index 7673b70..3cfd2a4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,12 +2,16 @@ FROM python:3.12-slim WORKDIR /app +RUN groupadd --system app && useradd --system --gid app app + RUN pip install --no-cache-dir uv -COPY pyproject.toml uv.lock README.md ./ +COPY --chown=app:app pyproject.toml uv.lock README.md ./ RUN uv sync --frozen --no-dev -COPY src/ ./src/ +COPY --chown=app:app src/ ./src/ + +USER app ENV PYTHONPATH=/app diff --git a/src/comfyui_mcp/client.py b/src/comfyui_mcp/client.py index 0716ece..41a1e3d 100644 --- a/src/comfyui_mcp/client.py +++ b/src/comfyui_mcp/client.py @@ -3,9 +3,30 @@ from __future__ import annotations import asyncio +import re import httpx +_ALLOWED_HTTP_METHODS = frozenset({"GET", "POST", "PUT", "DELETE", "PATCH"}) +_UUID_RE = re.compile(r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$") +_SAFE_SEGMENT_RE = re.compile(r"^[A-Za-z0-9_.\-]+$") + + +def _validate_prompt_id(prompt_id: str) -> str: + """Validate that a prompt_id is a well-formed UUID.""" + if not _UUID_RE.match(prompt_id): + raise ValueError(f"Invalid prompt_id format: {prompt_id!r}") + return prompt_id + + +def _validate_path_segment(value: str, *, label: str = "value") -> str: + """Validate that a value is safe for interpolation into a URL path segment.""" + if not value: + raise ValueError(f"{label} must not be empty") + if not _SAFE_SEGMENT_RE.match(value): + raise ValueError(f"{label} contains invalid characters: {value!r}") + return value + class ComfyUIClient: def __init__( @@ -38,11 +59,14 @@ async def _get_client(self) -> httpx.AsyncClient: async def _request(self, method: str, path: str, **kwargs) -> httpx.Response: """Make an HTTP request with retry logic for transient failures.""" + normalized = method.upper() + if normalized not in _ALLOWED_HTTP_METHODS: + raise ValueError(f"HTTP method not allowed: {method!r}") last_exception: Exception | None = None for attempt in range(self._max_retries): try: c = await self._get_client() - r = await getattr(c, method)(path, **kwargs) + r = await c.request(normalized, path, **kwargs) r.raise_for_status() return r except httpx.HTTPStatusError: @@ -88,6 +112,8 @@ async def get_models(self, folder: str) -> list: return r.json() async def get_object_info(self, node_class: str | None = None) -> dict: + if node_class is not None: + _validate_path_segment(node_class, label="node_class") path = f"/object_info/{node_class}" if node_class else "/object_info" r = await self._request("get", path) return r.json() @@ -97,6 +123,7 @@ async def get_history(self) -> dict: return r.json() async def get_history_item(self, prompt_id: str) -> dict: + _validate_prompt_id(prompt_id) r = await self._request("get", f"/history/{prompt_id}") return r.json() @@ -104,6 +131,7 @@ async def interrupt(self) -> None: await self._request("post", "/interrupt") async def delete_queue_item(self, prompt_id: str) -> None: + _validate_prompt_id(prompt_id) await self._request("post", "/queue", json={"delete": [prompt_id]}) async def upload_image(self, data: bytes, filename: str, subfolder: str = "") -> dict: @@ -228,6 +256,7 @@ async def get_download_tasks(self) -> list[dict]: async def delete_download_task(self, task_id: str) -> dict: """DELETE /model-manager/download/{task_id} — cancel and remove a download.""" + _validate_path_segment(task_id, label="task_id") r = await self._request("delete", f"/model-manager/download/{task_id}") payload = self._unwrap_model_manager_response(r.json()) if isinstance(payload, dict): diff --git a/src/comfyui_mcp/config.py b/src/comfyui_mcp/config.py index d388cc0..d0ef76f 100644 --- a/src/comfyui_mcp/config.py +++ b/src/comfyui_mcp/config.py @@ -105,6 +105,19 @@ class SSESettings(BaseModel): host: str = "127.0.0.1" port: int = 8080 + @field_validator("host") + @classmethod + def warn_non_localhost(cls, v: str) -> str: + if v not in ("127.0.0.1", "::1", "localhost"): + import logging + + logging.getLogger("comfyui_mcp.config").warning( + "SSE transport binding to %s — this exposes all MCP tools without " + "authentication. Only use behind a reverse proxy with auth and TLS.", + v, + ) + return v + class TransportSettings(BaseModel): sse: SSESettings = SSESettings() diff --git a/src/comfyui_mcp/workflow/validation.py b/src/comfyui_mcp/workflow/validation.py index be852f6..712ec6c 100644 --- a/src/comfyui_mcp/workflow/validation.py +++ b/src/comfyui_mcp/workflow/validation.py @@ -182,13 +182,11 @@ async def validate_workflow( inputs = node_data.get("inputs") if inputs is None: errors.append(f"Node '{node_id}': missing 'inputs'") - node_data["inputs"] = {} continue if not isinstance(inputs, dict): errors.append( f"Node '{node_id}': 'inputs' must be an object, got {type(inputs).__name__}" ) - node_data["inputs"] = {} continue for input_name, value in inputs.items(): if isinstance(value, list) and len(value) == 2 and isinstance(value[0], str): diff --git a/tests/test_audit.py b/tests/test_audit.py index f4eaf72..978fcd2 100644 --- a/tests/test_audit.py +++ b/tests/test_audit.py @@ -16,7 +16,7 @@ def test_record_serializes_to_json(self): record = AuditRecord( tool="run_workflow", action="submitted", - prompt_id="abc-123", + prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", nodes_used=["KSampler", "CLIPTextEncode"], warnings=["Dangerous node: EvalNode"], ) diff --git a/tests/test_client.py b/tests/test_client.py index 142fdcb..ce8e5d1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -31,10 +31,12 @@ async def test_get_queue(self, client): @respx.mock async def test_post_prompt(self, client): respx.post("http://test-comfyui:8188/prompt").mock( - return_value=httpx.Response(200, json={"prompt_id": "abc-123"}) + return_value=httpx.Response( + 200, json={"prompt_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + ) ) result = await client.post_prompt({"1": {"class_type": "KSampler", "inputs": {}}}) - assert result["prompt_id"] == "abc-123" + assert result["prompt_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" @respx.mock async def test_get_models(self, client): @@ -91,15 +93,17 @@ async def test_get_image(self, client): @respx.mock async def test_delete_queue_item(self, client): respx.post("http://test-comfyui:8188/queue").mock(return_value=httpx.Response(200, json={})) - await client.delete_queue_item("abc-123") + await client.delete_queue_item("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") @respx.mock async def test_get_history_item(self, client): - respx.get("http://test-comfyui:8188/history/abc-123").mock( - return_value=httpx.Response(200, json={"abc-123": {"outputs": {}}}) + respx.get("http://test-comfyui:8188/history/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee").mock( + return_value=httpx.Response( + 200, json={"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee": {"outputs": {}}} + ) ) - result = await client.get_history_item("abc-123") - assert "abc-123" in result + result = await client.get_history_item("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in result @respx.mock async def test_get_embeddings(self, client): diff --git a/tests/test_integration.py b/tests/test_integration.py index e9c8ee5..ed53e26 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -22,13 +22,15 @@ async def test_generate_image_lists_models_then_generates(self): return_value=httpx.Response(200, json=["sd_v15.safetensors"]) ) respx.post("http://mock-comfyui:8188/prompt").mock( - return_value=httpx.Response(200, json={"prompt_id": "test-001"}) + return_value=httpx.Response( + 200, json={"prompt_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + ) ) - respx.get("http://mock-comfyui:8188/history/test-001").mock( + respx.get("http://mock-comfyui:8188/history/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee").mock( return_value=httpx.Response( 200, json={ - "test-001": { + "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee": { "outputs": { "9": { "images": [ @@ -57,18 +59,20 @@ async def test_generate_image_lists_models_then_generates(self): # Step 2: Generate an image generate_fn = tools["generate_image"].fn result = await generate_fn(prompt="a sunset over mountains") - assert "test-001" in result + assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in result # Step 3: Check the job get_job_fn = tools["get_job"].fn - job = await get_job_fn(prompt_id="test-001") - assert "test-001" in job + job = await get_job_fn(prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in job @respx.mock async def test_run_workflow_with_dangerous_node_in_audit_mode(self): """Audit mode logs dangerous nodes but still submits the workflow.""" respx.post("http://mock-comfyui:8188/prompt").mock( - return_value=httpx.Response(200, json={"prompt_id": "danger-001"}) + return_value=httpx.Response( + 200, json={"prompt_id": "11111111-2222-3333-4444-555555555555"} + ) ) settings = Settings(comfyui=ComfyUISettings(url="http://mock-comfyui:8188")) @@ -77,7 +81,7 @@ async def test_run_workflow_with_dangerous_node_in_audit_mode(self): run_workflow_fn = server._tool_manager._tools["run_workflow"].fn workflow = json.dumps({"1": {"class_type": "Terminal", "inputs": {}}}) result = await run_workflow_fn(workflow=workflow) - assert "danger-001" in result + assert "11111111-2222-3333-4444-555555555555" in result assert "Terminal" in result async def test_run_workflow_blocked_in_enforce_mode(self): diff --git a/tests/test_progress.py b/tests/test_progress.py index c44361f..e3d25f8 100644 --- a/tests/test_progress.py +++ b/tests/test_progress.py @@ -47,8 +47,8 @@ async def __aexit__(self, *args): class TestProgressState: def test_default_state(self): - state = ProgressState(prompt_id="abc-123") - assert state.prompt_id == "abc-123" + state = ProgressState(prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + assert state.prompt_id == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" assert state.status == "unknown" assert state.step is None assert state.total_steps is None @@ -56,9 +56,11 @@ def test_default_state(self): assert state.outputs == [] def test_to_dict_omits_none_fields(self): - state = ProgressState(prompt_id="abc-123", status="queued", queue_position=3) + state = ProgressState( + prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", status="queued", queue_position=3 + ) d = state.to_dict() - assert d["prompt_id"] == "abc-123" + assert d["prompt_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" assert d["status"] == "queued" assert d["queue_position"] == 3 assert "step" not in d @@ -67,7 +69,7 @@ def test_to_dict_omits_none_fields(self): def test_to_dict_includes_set_fields(self): state = ProgressState( - prompt_id="abc-123", + prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", status="running", step=5, total_steps=20, @@ -118,7 +120,7 @@ def fake_connect(url, **kwargs): monkeypatch.setattr("comfyui_mcp.progress.websockets.connect", fake_connect) - state = await progress.wait_for_completion("prompt-1") + state = await progress.wait_for_completion("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") assert state.status == "completed" assert state.step == 20 assert state.total_steps == 20 @@ -148,7 +150,7 @@ def fake_connect(url, **kwargs): monkeypatch.setattr("comfyui_mcp.progress.websockets.connect", fake_connect) - state = await progress.wait_for_completion("prompt-1") + state = await progress.wait_for_completion("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") assert state.status == "error" async def test_wait_for_completion_timeout(self, monkeypatch): @@ -174,7 +176,7 @@ def fake_connect(url, **kwargs): monkeypatch.setattr("comfyui_mcp.progress.websockets.connect", fake_connect) - state = await progress.wait_for_completion("prompt-1") + state = await progress.wait_for_completion("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") assert state.status == "timeout" @respx.mock @@ -188,6 +190,7 @@ def fail_connect(url, **kwargs): monkeypatch.setattr("comfyui_mcp.progress.websockets.connect", fail_connect) # Simulate: first poll returns "running", second returns "completed" + prompt_id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" history_call_count = 0 def history_side_effect(request): @@ -197,7 +200,7 @@ def history_side_effect(request): return httpx.Response( 200, json={ - "prompt-1": { + prompt_id: { "outputs": { "9": {"images": [{"filename": "out.png", "subfolder": ""}]} }, @@ -207,12 +210,12 @@ def history_side_effect(request): ) return httpx.Response(200, json={}) - respx.get("http://test:8188/history/prompt-1").mock(side_effect=history_side_effect) + respx.get(f"http://test:8188/history/{prompt_id}").mock(side_effect=history_side_effect) respx.get("http://test:8188/queue").mock( return_value=httpx.Response( 200, json={ - "queue_running": [["0", "prompt-1", {}, {}]], + "queue_running": [["0", prompt_id, {}, {}]], "queue_pending": [], }, ) @@ -221,7 +224,7 @@ def history_side_effect(request): # Patch sleep to avoid real delays in tests monkeypatch.setattr("comfyui_mcp.progress.asyncio.sleep", AsyncMock()) - state = await progress.wait_for_completion("prompt-1") + state = await progress.wait_for_completion(prompt_id) assert state.status == "completed" assert state.elapsed_seconds is not None @@ -230,11 +233,11 @@ async def test_get_state_http_fallback_completed(self): client = ComfyUIClient(base_url="http://test:8188") progress = WebSocketProgress(client, timeout=10.0) - respx.get("http://test:8188/history/abc-123").mock( + respx.get("http://test:8188/history/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee").mock( return_value=httpx.Response( 200, json={ - "abc-123": { + "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee": { "outputs": { "9": {"images": [{"filename": "img.png", "subfolder": "output"}]} }, @@ -253,7 +256,7 @@ async def test_get_state_http_fallback_completed(self): ) ) - state = await progress.get_state("abc-123") + state = await progress.get_state("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") assert state.status == "completed" assert len(state.outputs) == 1 assert state.outputs[0]["filename"] == "img.png" @@ -263,7 +266,7 @@ async def test_get_state_http_fallback_queued(self): client = ComfyUIClient(base_url="http://test:8188") progress = WebSocketProgress(client, timeout=10.0) - respx.get("http://test:8188/history/abc-123").mock( + respx.get("http://test:8188/history/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee").mock( return_value=httpx.Response(200, json={}) ) respx.get("http://test:8188/queue").mock( @@ -273,13 +276,13 @@ async def test_get_state_http_fallback_queued(self): "queue_running": [], "queue_pending": [ [0, "other-id", {}, {}], - [1, "abc-123", {}, {}], + [1, "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", {}, {}], ], }, ) ) - state = await progress.get_state("abc-123") + state = await progress.get_state("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") assert state.status == "queued" assert state.queue_position == 2 @@ -288,18 +291,18 @@ async def test_get_state_http_fallback_running(self): client = ComfyUIClient(base_url="http://test:8188") progress = WebSocketProgress(client, timeout=10.0) - respx.get("http://test:8188/history/abc-123").mock( + respx.get("http://test:8188/history/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee").mock( return_value=httpx.Response(200, json={}) ) respx.get("http://test:8188/queue").mock( return_value=httpx.Response( 200, json={ - "queue_running": [[0, "abc-123", {}, {}]], + "queue_running": [[0, "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", {}, {}]], "queue_pending": [], }, ) ) - state = await progress.get_state("abc-123") + state = await progress.get_state("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") assert state.status == "running" diff --git a/tests/test_tools_generation.py b/tests/test_tools_generation.py index 77a0f8d..a612751 100644 --- a/tests/test_tools_generation.py +++ b/tests/test_tools_generation.py @@ -75,25 +75,29 @@ class TestRunWorkflow: async def test_submits_workflow(self, components): client, audit, limiter, inspector = components respx.post("http://test:8188/prompt").mock( - return_value=httpx.Response(200, json={"prompt_id": "abc-123"}) + return_value=httpx.Response( + 200, json={"prompt_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + ) ) mcp = FastMCP("test") tools = register_generation_tools(mcp, client, audit, limiter, inspector) workflow = {"1": {"class_type": "KSampler", "inputs": {}}} result = await tools["run_workflow"](workflow=json.dumps(workflow)) - assert "abc-123" in result + assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in result @respx.mock async def test_audit_mode_logs_dangerous_nodes(self, components): client, audit, limiter, inspector = components respx.post("http://test:8188/prompt").mock( - return_value=httpx.Response(200, json={"prompt_id": "abc-123"}) + return_value=httpx.Response( + 200, json={"prompt_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + ) ) mcp = FastMCP("test") tools = register_generation_tools(mcp, client, audit, limiter, inspector) workflow = {"1": {"class_type": "EvalNode", "inputs": {}}} result = await tools["run_workflow"](workflow=json.dumps(workflow)) - assert "abc-123" in result + assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in result assert "EvalNode" in result async def test_enforce_mode_blocks_unapproved(self, enforce_components): @@ -812,7 +816,9 @@ async def test_run_workflow_warns_missing_model(self, components): return_value=httpx.Response(200, json=["other_model.safetensors"]) ) respx.post("http://test:8188/prompt").mock( - return_value=httpx.Response(200, json={"prompt_id": "abc-123"}) + return_value=httpx.Response( + 200, json={"prompt_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + ) ) mcp_server = FastMCP("test") model_checker = ModelChecker() diff --git a/tests/test_tools_jobs.py b/tests/test_tools_jobs.py index 816e194..402a866 100644 --- a/tests/test_tools_jobs.py +++ b/tests/test_tools_jobs.py @@ -52,7 +52,7 @@ async def test_cancel_job_sends_delete(self, components): route = respx.post("http://test:8188/queue").mock(return_value=httpx.Response(200, json={})) mcp = FastMCP("test") tools = register_job_tools(mcp, client, audit, limiter) - await tools["cancel_job"](prompt_id="abc-123") + await tools["cancel_job"](prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") assert route.called @@ -73,13 +73,16 @@ class TestGetJob: @respx.mock async def test_get_job_returns_history_item(self, components): client, audit, limiter = components - respx.get("http://test:8188/history/abc-123").mock( - return_value=httpx.Response(200, json={"abc-123": {"outputs": {"9": {"images": []}}}}) + respx.get("http://test:8188/history/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee").mock( + return_value=httpx.Response( + 200, + json={"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee": {"outputs": {"9": {"images": []}}}}, + ) ) mcp = FastMCP("test") tools = register_job_tools(mcp, client, audit, limiter) - result = await tools["get_job"](prompt_id="abc-123") - assert "abc-123" in result + result = await tools["get_job"](prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in result class TestGetQueueStatus: @@ -111,11 +114,11 @@ class TestGetProgress: @respx.mock async def test_returns_completed_state(self, progress_components): client, audit, limiter, read_limiter, progress = progress_components - respx.get("http://test:8188/history/abc-123").mock( + respx.get("http://test:8188/history/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee").mock( return_value=httpx.Response( 200, json={ - "abc-123": { + "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee": { "outputs": { "9": {"images": [{"filename": "out.png", "subfolder": "output"}]} }, @@ -142,16 +145,19 @@ async def test_returns_completed_state(self, progress_components): read_limiter=read_limiter, progress=progress, ) - result = await tools["get_progress"](prompt_id="abc-123") + result = await tools["get_progress"](prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") data = json.loads(result) assert data["status"] == "completed" - assert data["prompt_id"] == "abc-123" + assert data["prompt_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" assert len(data["outputs"]) == 1 @respx.mock async def test_returns_unknown_when_not_found(self, progress_components): client, audit, limiter, read_limiter, progress = progress_components - respx.get("http://test:8188/history/nope").mock(return_value=httpx.Response(200, json={})) + not_found_id = "11111111-2222-3333-4444-555555555555" + respx.get(f"http://test:8188/history/{not_found_id}").mock( + return_value=httpx.Response(200, json={}) + ) respx.get("http://test:8188/queue").mock( return_value=httpx.Response( 200, @@ -170,6 +176,6 @@ async def test_returns_unknown_when_not_found(self, progress_components): read_limiter=read_limiter, progress=progress, ) - result = await tools["get_progress"](prompt_id="nope") + result = await tools["get_progress"](prompt_id=not_found_id) data = json.loads(result) assert data["status"] == "unknown"