Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 55 additions & 7 deletions src/comfyui_mcp/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@

from __future__ import annotations

import asyncio
import logging
import threading
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

from pydantic import BaseModel, Field, model_serializer

_logger = logging.getLogger(__name__)

_SENSITIVE_KEYS = {"token", "password", "secret", "api_key", "authorization"}


Expand Down Expand Up @@ -51,16 +57,58 @@ def serialize(self) -> dict[str, object]:
class AuditLogger:
def __init__(self, audit_file: Path) -> None:
self._audit_file = Path(audit_file)
self._dir_created = False
self._lock = threading.Lock()

def log(self, *, tool: str, action: str, **kwargs) -> AuditRecord:
"""Write an audit record as a JSON line."""
record = AuditRecord(tool=tool, action=action, **kwargs)
def _is_path_safe(self) -> bool:
"""Check that neither the audit file nor any parent is a symlink.

Uses is_symlink() which detects both live and dangling symlinks
(unlike exists() which returns False for dangling symlinks).
"""
if self._audit_file.is_symlink():
return False
return all(not parent.is_symlink() for parent in self._audit_file.parents)

def _ensure_dir(self) -> bool:
"""Create parent directories once. Returns False on failure."""
if self._dir_created:
return True
try:
self._audit_file.parent.mkdir(parents=True, exist_ok=True)
with open(self._audit_file, "a") as f:
f.write(record.model_dump_json() + "\n")
self._dir_created = True
return True
except OSError as e:
Comment on lines +73 to +81
import logging
_logger.error("AUDIT LOG FAILURE: cannot create directory: %s", e)
return False

def _write_record(self, record: AuditRecord) -> None:
"""Synchronous, thread-safe write of a single audit record."""
with self._lock:
# Check symlink safety on every write (not cached) to detect
# post-init symlink swaps on the file or any parent directory
if not self._is_path_safe():
_logger.error(
"AUDIT LOG REFUSED: path contains symlink: %s",
self._audit_file,
)
return
if not self._ensure_dir():
return
try:
with open(self._audit_file, "a") as f:
f.write(record.model_dump_json() + "\n")
except OSError as e:
_logger.error("AUDIT LOG FAILURE: %s", e)

logging.getLogger(__name__).error("AUDIT LOG FAILURE: %s", e)
def log(self, *, tool: str, action: str, **kwargs: Any) -> AuditRecord:
"""Write an audit record as a JSON line (synchronous)."""
record = AuditRecord(tool=tool, action=action, **kwargs)
self._write_record(record)
return record

async def async_log(self, *, tool: str, action: str, **kwargs: Any) -> AuditRecord:
"""Write an audit record without blocking the event loop."""
record = AuditRecord(tool=tool, action=action, **kwargs)
await asyncio.to_thread(self._write_record, record)
Comment on lines +85 to +113
return record
30 changes: 27 additions & 3 deletions src/comfyui_mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

from __future__ import annotations

import atexit
import contextlib
from pathlib import Path

import httpx
from mcp.server.fastmcp import FastMCP

from comfyui_mcp.audit import AuditLogger
Expand Down Expand Up @@ -83,6 +86,7 @@ def _register_all_tools(
download_validator: DownloadValidator,
model_checker: ModelChecker,
model_search_settings: ModelSearchSettings,
search_http: httpx.AsyncClient,
) -> None:
"""Register all MCP tool groups with their dependencies."""
register_discovery_tools(server, client, audit, rate_limiters["read"], sanitizer, node_auditor)
Expand Down Expand Up @@ -118,10 +122,13 @@ def _register_all_tools(
detector=detector,
validator=download_validator,
search_settings=model_search_settings,
search_http=search_http,
)


def _build_server(settings: Settings | None = None) -> tuple[FastMCP, Settings]:
def _build_server(
settings: Settings | None = None,
) -> tuple[FastMCP, Settings, ComfyUIClient, httpx.AsyncClient]:
"""Build and configure the MCP server with all tools registered."""
if settings is None:
settings = load_settings()
Expand All @@ -144,6 +151,7 @@ def _build_server(settings: Settings | None = None) -> tuple[FastMCP, Settings]:
allowed_extensions=settings.security.allowed_model_extensions,
)
model_checker = ModelChecker()
search_http = httpx.AsyncClient(timeout=httpx.Timeout(connect=10, read=30, write=10, pool=10))

server_kwargs: dict = {
"name": "ComfyUI",
Expand Down Expand Up @@ -183,13 +191,29 @@ def _build_server(settings: Settings | None = None) -> tuple[FastMCP, Settings]:
download_validator=download_validator,
model_checker=model_checker,
model_search_settings=settings.model_search,
search_http=search_http,
)

return server, settings
return server, settings, client, search_http


# Module-level server instance for import and CLI use
mcp, _settings = _build_server()
mcp, _settings, _client, _search_http = _build_server()


def _cleanup() -> None:
    """Best-effort cleanup of HTTP clients on process exit."""
    import asyncio

    async def _shutdown() -> None:
        await _client.close()
        await _search_http.aclose()

    try:
        # Best-effort at interpreter exit: never let cleanup raise.
        asyncio.run(_shutdown())
    except Exception:
        pass


atexit.register(_cleanup)


def main() -> None:
Expand Down
28 changes: 15 additions & 13 deletions src/comfyui_mcp/tools/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def list_models(folder: str = "checkpoints") -> list[str]:
"""List available models in a folder (checkpoints, loras, vae, etc.)."""
limiter.check("list_models")
sanitizer.validate_path_segment(folder, label="folder")
audit.log(tool="list_models", action="called", extra={"folder": folder})
await audit.async_log(tool="list_models", action="called", extra={"folder": folder})
return await client.get_models(folder)

tool_fns["list_models"] = list_models
Expand All @@ -168,7 +168,7 @@ async def list_models(folder: str = "checkpoints") -> list[str]:
async def list_nodes() -> list[str]:
"""List all available ComfyUI node types."""
limiter.check("list_nodes")
audit.log(tool="list_nodes", action="called")
await audit.async_log(tool="list_nodes", action="called")
info = await client.get_object_info()
return sorted(info.keys())

Expand All @@ -178,7 +178,9 @@ async def list_nodes() -> list[str]:
async def get_node_info(node_class: str) -> dict:
"""Get detailed information about a specific node type."""
limiter.check("get_node_info")
audit.log(tool="get_node_info", action="called", extra={"node_class": node_class})
await audit.async_log(
tool="get_node_info", action="called", extra={"node_class": node_class}
)
return await client.get_object_info(node_class)

tool_fns["get_node_info"] = get_node_info
Expand All @@ -187,7 +189,7 @@ async def get_node_info(node_class: str) -> dict:
async def list_workflows() -> list:
"""List available workflow templates."""
limiter.check("list_workflows")
audit.log(tool="list_workflows", action="called")
await audit.async_log(tool="list_workflows", action="called")
return await client.get_workflow_templates()

tool_fns["list_workflows"] = list_workflows
Expand All @@ -196,7 +198,7 @@ async def list_workflows() -> list:
async def list_extensions() -> list:
"""List available ComfyUI extensions."""
limiter.check("list_extensions")
audit.log(tool="list_extensions", action="called")
await audit.async_log(tool="list_extensions", action="called")
return await client.get_extensions()

tool_fns["list_extensions"] = list_extensions
Expand All @@ -205,7 +207,7 @@ async def list_extensions() -> list:
async def get_server_features() -> dict:
"""Get ComfyUI server features and capabilities."""
limiter.check("get_server_features")
audit.log(tool="get_server_features", action="called")
await audit.async_log(tool="get_server_features", action="called")
return await client.get_features()

tool_fns["get_server_features"] = get_server_features
Expand All @@ -214,7 +216,7 @@ async def get_server_features() -> dict:
async def list_model_folders() -> list[str]:
"""List available model folder types (checkpoints, loras, vae, etc.)."""
limiter.check("list_model_folders")
audit.log(tool="list_model_folders", action="called")
await audit.async_log(tool="list_model_folders", action="called")
return await client.get_model_types()

tool_fns["list_model_folders"] = list_model_folders
Expand All @@ -230,7 +232,7 @@ async def get_model_metadata(folder: str, filename: str) -> dict:
limiter.check("get_model_metadata")
sanitizer.validate_path_segment(folder, label="folder")
sanitizer.validate_path_segment(filename, label="filename")
audit.log(
await audit.async_log(
tool="get_model_metadata",
action="called",
extra={"folder": folder, "filename": filename},
Expand All @@ -250,7 +252,7 @@ async def audit_dangerous_nodes() -> dict:
Dictionary with dangerous and suspicious node counts and lists
"""
limiter.check("audit_dangerous_nodes")
audit.log(tool="audit_dangerous_nodes", action="started")
await audit.async_log(tool="audit_dangerous_nodes", action="started")

auditor = node_auditor if node_auditor else NodeAuditor()

Expand All @@ -273,7 +275,7 @@ async def audit_dangerous_nodes() -> dict:
},
}

audit.log(
await audit.async_log(
tool="audit_dangerous_nodes",
action="completed",
extra={
Expand All @@ -300,7 +302,7 @@ async def get_system_info() -> dict:
queue (running/pending counts).
"""
limiter.check("get_system_info")
audit.log(tool="get_system_info", action="called")
await audit.async_log(tool="get_system_info", action="called")

raw = await client.get_system_stats()
queue_raw = await client.get_queue()
Expand Down Expand Up @@ -353,7 +355,7 @@ async def get_model_presets(
Dictionary containing normalized family and recommended settings.
"""
limiter.check("get_model_presets")
audit.log(
await audit.async_log(
tool="get_model_presets",
action="called",
extra={"model_name": model_name, "model_family": model_family},
Expand Down Expand Up @@ -389,7 +391,7 @@ async def get_prompting_guide(model_family: str) -> dict[str, Any]:
"""
limiter.check("get_prompting_guide")
normalized = _normalize_model_family(model_family)
audit.log(
await audit.async_log(
tool="get_prompting_guide",
action="called",
extra={"model_family": normalized},
Expand Down
18 changes: 10 additions & 8 deletions src/comfyui_mcp/tools/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ async def upload_image(filename: str, image_data: str, subfolder: str = "") -> s
clean_subfolder = sanitizer.validate_subfolder(subfolder)
raw = base64.b64decode(image_data)
sanitizer.validate_size(len(raw))
audit.log(
await audit.async_log(
tool="upload_image",
action="uploading",
extra={"filename": clean_name, "size_bytes": len(raw)},
)
result = await client.upload_image(raw, clean_name, clean_subfolder)
audit.log(tool="upload_image", action="uploaded", extra={"result": result})
await audit.async_log(tool="upload_image", action="uploaded", extra={"result": result})
return f"Uploaded {result.get('name', clean_name)} to ComfyUI input directory"

tool_fns["upload_image"] = upload_image
Expand All @@ -128,7 +128,9 @@ async def get_image(filename: str, subfolder: str = "output") -> str:
limiter.check("get_image")
clean_name = sanitizer.validate_filename(filename)
clean_subfolder = sanitizer.validate_subfolder(subfolder)
audit.log(tool="get_image", action="downloading", extra={"filename": clean_name})
await audit.async_log(
tool="get_image", action="downloading", extra={"filename": clean_name}
)
data, content_type = await client.get_image(clean_name, clean_subfolder)
b64 = base64.b64encode(data).decode()
return f"data:{content_type};base64,{b64}"
Expand All @@ -139,7 +141,7 @@ async def get_image(filename: str, subfolder: str = "output") -> str:
async def list_outputs() -> list[str]:
"""List files in ComfyUI's output directory."""
limiter.check("list_outputs")
audit.log(tool="list_outputs", action="called")
await audit.async_log(tool="list_outputs", action="called")
history = await client.get_history()
filenames = set()
for entry in history.values():
Expand Down Expand Up @@ -167,13 +169,13 @@ async def upload_mask(filename: str, mask_data: str, subfolder: str = "") -> str
clean_subfolder = sanitizer.validate_subfolder(subfolder)
raw = base64.b64decode(mask_data)
sanitizer.validate_size(len(raw))
audit.log(
await audit.async_log(
tool="upload_mask",
action="uploading",
extra={"filename": clean_name, "size_bytes": len(raw)},
)
result = await client.upload_mask(raw, clean_name, clean_subfolder)
audit.log(tool="upload_mask", action="uploaded", extra={"result": result})
await audit.async_log(tool="upload_mask", action="uploaded", extra={"result": result})
return f"Uploaded mask {result.get('name', clean_name)} to ComfyUI input directory"

tool_fns["upload_mask"] = upload_mask
Expand All @@ -197,7 +199,7 @@ async def get_workflow_from_image(filename: str, subfolder: str = "output") -> d
limiter.check("get_workflow_from_image")
clean_name = sanitizer.validate_filename(filename)
clean_subfolder = sanitizer.validate_subfolder(subfolder)
audit.log(
await audit.async_log(
tool="get_workflow_from_image",
action="extracting",
extra={"filename": clean_name, "subfolder": clean_subfolder},
Expand Down Expand Up @@ -235,7 +237,7 @@ async def get_workflow_from_image(filename: str, subfolder: str = "output") -> d
else:
message = "No workflow metadata found in this image"

audit.log(
await audit.async_log(
tool="get_workflow_from_image",
action="extracted",
extra={
Expand Down
Loading
Loading