Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4

- name: Install uv
uses: astral-sh/setup-uv@v4
uses: astral-sh/setup-uv@e4db8464a088ece1b920f60402e813ea4de65b8f # v4

- name: Set up Python
run: uv python install 3.12
Expand All @@ -30,10 +30,10 @@ jobs:
type-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4

- name: Install uv
uses: astral-sh/setup-uv@v4
uses: astral-sh/setup-uv@e4db8464a088ece1b920f60402e813ea4de65b8f # v4

- name: Set up Python
run: uv python install 3.12
Expand All @@ -47,10 +47,10 @@ jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4

- name: Install uv
uses: astral-sh/setup-uv@v4
uses: astral-sh/setup-uv@e4db8464a088ece1b920f60402e813ea4de65b8f # v4

- name: Set up Python
run: uv python install 3.12
Expand Down
12 changes: 6 additions & 6 deletions .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@ jobs:

steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4

- name: Set up QEMU
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@c7c53464625b32c7a7e944ae62b3e17d2b600130 # v3

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3

- name: Log in to Container Registry
uses: docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5
with:
images: ghcr.io/${{ github.repository }}
tags: |
Expand All @@ -46,7 +46,7 @@ jobs:
type=sha,prefix=

- name: Build and push
uses: docker/build-push-action@v5
uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5
with:
context: .
push: true
Expand Down
8 changes: 6 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ FROM python:3.12-slim

WORKDIR /app

RUN groupadd --system app && useradd --system --gid app app

RUN pip install --no-cache-dir uv

COPY pyproject.toml uv.lock README.md ./
COPY --chown=app:app pyproject.toml uv.lock README.md ./
RUN uv sync --frozen --no-dev

COPY src/ ./src/
COPY --chown=app:app src/ ./src/

USER app

ENV PYTHONPATH=/app

Expand Down
31 changes: 30 additions & 1 deletion src/comfyui_mcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,30 @@
from __future__ import annotations

import asyncio
import re

import httpx

_ALLOWED_HTTP_METHODS = frozenset({"GET", "POST", "PUT", "DELETE", "PATCH"})
_UUID_RE = re.compile(r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$")
_SAFE_SEGMENT_RE = re.compile(r"^[A-Za-z0-9_.\-]+$")


def _validate_prompt_id(prompt_id: str) -> str:
"""Validate that a prompt_id is a well-formed UUID."""
if not _UUID_RE.match(prompt_id):
raise ValueError(f"Invalid prompt_id format: {prompt_id!r}")
return prompt_id


def _validate_path_segment(value: str, *, label: str = "value") -> str:
"""Validate that a value is safe for interpolation into a URL path segment."""
if not value:
raise ValueError(f"{label} must not be empty")
if not _SAFE_SEGMENT_RE.match(value):
raise ValueError(f"{label} contains invalid characters: {value!r}")
return value


class ComfyUIClient:
def __init__(
Expand Down Expand Up @@ -38,11 +59,14 @@ async def _get_client(self) -> httpx.AsyncClient:

async def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
"""Make an HTTP request with retry logic for transient failures."""
normalized = method.upper()
if normalized not in _ALLOWED_HTTP_METHODS:
raise ValueError(f"HTTP method not allowed: {method!r}")
last_exception: Exception | None = None
for attempt in range(self._max_retries):
try:
c = await self._get_client()
r = await getattr(c, method)(path, **kwargs)
r = await c.request(normalized, path, **kwargs)
r.raise_for_status()
return r
except httpx.HTTPStatusError:
Expand Down Expand Up @@ -88,6 +112,8 @@ async def get_models(self, folder: str) -> list:
return r.json()

async def get_object_info(self, node_class: str | None = None) -> dict:
if node_class is not None:
_validate_path_segment(node_class, label="node_class")
path = f"/object_info/{node_class}" if node_class else "/object_info"
r = await self._request("get", path)
return r.json()
Expand All @@ -97,13 +123,15 @@ async def get_history(self) -> dict:
return r.json()

async def get_history_item(self, prompt_id: str) -> dict:
_validate_prompt_id(prompt_id)
r = await self._request("get", f"/history/{prompt_id}")
return r.json()

async def interrupt(self) -> None:
await self._request("post", "/interrupt")

async def delete_queue_item(self, prompt_id: str) -> None:
_validate_prompt_id(prompt_id)
await self._request("post", "/queue", json={"delete": [prompt_id]})

async def upload_image(self, data: bytes, filename: str, subfolder: str = "") -> dict:
Expand Down Expand Up @@ -228,6 +256,7 @@ async def get_download_tasks(self) -> list[dict]:

async def delete_download_task(self, task_id: str) -> dict:
"""DELETE /model-manager/download/{task_id} — cancel and remove a download."""
_validate_path_segment(task_id, label="task_id")
r = await self._request("delete", f"/model-manager/download/{task_id}")
payload = self._unwrap_model_manager_response(r.json())
if isinstance(payload, dict):
Expand Down
13 changes: 13 additions & 0 deletions src/comfyui_mcp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,19 @@ class SSESettings(BaseModel):
host: str = "127.0.0.1"
port: int = 8080

@field_validator("host")
@classmethod
def warn_non_localhost(cls, v: str) -> str:
if v not in ("127.0.0.1", "::1", "localhost"):
import logging

logging.getLogger("comfyui_mcp.config").warning(
"SSE transport binding to %s — this exposes all MCP tools without "
"authentication. Only use behind a reverse proxy with auth and TLS.",
v,
)
return v


class TransportSettings(BaseModel):
sse: SSESettings = SSESettings()
Expand Down
2 changes: 0 additions & 2 deletions src/comfyui_mcp/workflow/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,11 @@ async def validate_workflow(
inputs = node_data.get("inputs")
if inputs is None:
errors.append(f"Node '{node_id}': missing 'inputs'")
node_data["inputs"] = {}
continue
if not isinstance(inputs, dict):
errors.append(
f"Node '{node_id}': 'inputs' must be an object, got {type(inputs).__name__}"
)
node_data["inputs"] = {}
continue
for input_name, value in inputs.items():
if isinstance(value, list) and len(value) == 2 and isinstance(value[0], str):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_record_serializes_to_json(self):
record = AuditRecord(
tool="run_workflow",
action="submitted",
prompt_id="abc-123",
prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
nodes_used=["KSampler", "CLIPTextEncode"],
warnings=["Dangerous node: EvalNode"],
)
Expand Down
18 changes: 11 additions & 7 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ async def test_get_queue(self, client):
@respx.mock
async def test_post_prompt(self, client):
respx.post("http://test-comfyui:8188/prompt").mock(
return_value=httpx.Response(200, json={"prompt_id": "abc-123"})
return_value=httpx.Response(
200, json={"prompt_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"}
)
)
result = await client.post_prompt({"1": {"class_type": "KSampler", "inputs": {}}})
assert result["prompt_id"] == "abc-123"
assert result["prompt_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"

@respx.mock
async def test_get_models(self, client):
Expand Down Expand Up @@ -91,15 +93,17 @@ async def test_get_image(self, client):
@respx.mock
async def test_delete_queue_item(self, client):
respx.post("http://test-comfyui:8188/queue").mock(return_value=httpx.Response(200, json={}))
await client.delete_queue_item("abc-123")
await client.delete_queue_item("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")

@respx.mock
async def test_get_history_item(self, client):
respx.get("http://test-comfyui:8188/history/abc-123").mock(
return_value=httpx.Response(200, json={"abc-123": {"outputs": {}}})
respx.get("http://test-comfyui:8188/history/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee").mock(
return_value=httpx.Response(
200, json={"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee": {"outputs": {}}}
)
)
result = await client.get_history_item("abc-123")
assert "abc-123" in result
result = await client.get_history_item("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")
assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in result

@respx.mock
async def test_get_embeddings(self, client):
Expand Down
20 changes: 12 additions & 8 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ async def test_generate_image_lists_models_then_generates(self):
return_value=httpx.Response(200, json=["sd_v15.safetensors"])
)
respx.post("http://mock-comfyui:8188/prompt").mock(
return_value=httpx.Response(200, json={"prompt_id": "test-001"})
return_value=httpx.Response(
200, json={"prompt_id": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"}
)
)
respx.get("http://mock-comfyui:8188/history/test-001").mock(
respx.get("http://mock-comfyui:8188/history/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee").mock(
return_value=httpx.Response(
200,
json={
"test-001": {
"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee": {
"outputs": {
"9": {
"images": [
Expand Down Expand Up @@ -57,18 +59,20 @@ async def test_generate_image_lists_models_then_generates(self):
# Step 2: Generate an image
generate_fn = tools["generate_image"].fn
result = await generate_fn(prompt="a sunset over mountains")
assert "test-001" in result
assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in result

# Step 3: Check the job
get_job_fn = tools["get_job"].fn
job = await get_job_fn(prompt_id="test-001")
assert "test-001" in job
job = await get_job_fn(prompt_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")
assert "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" in job

@respx.mock
async def test_run_workflow_with_dangerous_node_in_audit_mode(self):
"""Audit mode logs dangerous nodes but still submits the workflow."""
respx.post("http://mock-comfyui:8188/prompt").mock(
return_value=httpx.Response(200, json={"prompt_id": "danger-001"})
return_value=httpx.Response(
200, json={"prompt_id": "11111111-2222-3333-4444-555555555555"}
)
)

settings = Settings(comfyui=ComfyUISettings(url="http://mock-comfyui:8188"))
Expand All @@ -77,7 +81,7 @@ async def test_run_workflow_with_dangerous_node_in_audit_mode(self):
run_workflow_fn = server._tool_manager._tools["run_workflow"].fn
workflow = json.dumps({"1": {"class_type": "Terminal", "inputs": {}}})
result = await run_workflow_fn(workflow=workflow)
assert "danger-001" in result
assert "11111111-2222-3333-4444-555555555555" in result
assert "Terminal" in result

async def test_run_workflow_blocked_in_enforce_mode(self):
Expand Down
Loading
Loading