Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions src/comfyui_mcp/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from __future__ import annotations

import logging
from datetime import UTC, datetime
from pathlib import Path

from pydantic import BaseModel, Field, model_serializer

_SENSITIVE_KEYS = {"token", "password", "secret", "api_key", "authorization"}

logger = logging.getLogger(__name__)


def _redact_sensitive(data: dict[str, object]) -> dict[str, object]:
"""Remove sensitive keys from a dictionary."""
Expand Down Expand Up @@ -51,16 +54,31 @@ def serialize(self) -> dict[str, object]:
class AuditLogger:
def __init__(self, audit_file: Path) -> None:
self._audit_file = Path(audit_file)
self._dir_created = False

def _ensure_directory(self) -> None:
"""Create parent directories on first write, with symlink check."""
if self._dir_created:
return
parent = self._audit_file.parent
parent.mkdir(parents=True, exist_ok=True)
# Reject symlinked audit file — could redirect entries to attacker-controlled path
if self._audit_file.exists() and self._audit_file.is_symlink():
raise OSError(f"Audit log path is a symlink — refusing to write: {self._audit_file}")
Comment on lines +65 to +67
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The symlink check is currently gated by self._audit_file.exists(). A dangling symlink returns exists() == False but is_symlink() == True, so this would not be rejected and open(..., 'a') would follow the symlink target. Consider checking is_symlink() unconditionally (no exists() guard), ideally immediately before each open to reduce TOCTOU risk.

Copilot uses AI. Check for mistakes.
self._dir_created = True
Comment on lines +61 to +68
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_ensure_directory() returns early once _dir_created is set, which also skips the symlink safety check on subsequent writes. That allows the audit file to be swapped to a symlink after the first successful (or failed) write without being detected. Consider caching only the directory creation work, but re-checking symlink/hardlink safety on every log() call (or using an os.open-based approach with O_NOFOLLOW on supported platforms).

Suggested change
if self._dir_created:
return
parent = self._audit_file.parent
parent.mkdir(parents=True, exist_ok=True)
# Reject symlinked audit file — could redirect entries to attacker-controlled path
if self._audit_file.exists() and self._audit_file.is_symlink():
raise OSError(f"Audit log path is a symlink — refusing to write: {self._audit_file}")
self._dir_created = True
parent = self._audit_file.parent
if not self._dir_created:
parent.mkdir(parents=True, exist_ok=True)
self._dir_created = True
# Reject symlinked audit file — could redirect entries to attacker-controlled path
if self._audit_file.exists() and self._audit_file.is_symlink():
raise OSError(f"Audit log path is a symlink — refusing to write: {self._audit_file}")

Copilot uses AI. Check for mistakes.

def log(self, *, tool: str, action: str, **kwargs) -> AuditRecord:
"""Write an audit record as a JSON line."""
"""Write an audit record as a JSON line.

Raises OSError on write failure — audit log integrity is a security
requirement, so failures must not be silently swallowed.
"""
record = AuditRecord(tool=tool, action=action, **kwargs)
try:
self._audit_file.parent.mkdir(parents=True, exist_ok=True)
self._ensure_directory()
with open(self._audit_file, "a") as f:
f.write(record.model_dump_json() + "\n")
except OSError as e:
import logging

logging.getLogger(__name__).error("AUDIT LOG FAILURE: %s", e)
except OSError:
logger.exception("AUDIT LOG FAILURE — cannot write to %s", self._audit_file)
raise
return record
13 changes: 13 additions & 0 deletions src/comfyui_mcp/tools/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
# Extensions we look for when finding the primary model file in HuggingFace repos
_MODEL_EXTENSIONS = {".safetensors", ".ckpt", ".pt", ".pth", ".bin"}

# Input validation limits for external API queries
_MAX_QUERY_LENGTH = 200
_MAX_MODEL_TYPE_LENGTH = 50


async def _search_civitai(
query: str,
Expand Down Expand Up @@ -173,6 +177,15 @@ async def search_models(
if source not in ("civitai", "huggingface"):
raise ValueError("source must be 'civitai' or 'huggingface'")

if not query or not query.strip():
raise ValueError("query must not be empty")
if len(query) > _MAX_QUERY_LENGTH:
raise ValueError(f"query too long ({len(query)} chars, max {_MAX_QUERY_LENGTH})")
if model_type and len(model_type) > _MAX_MODEL_TYPE_LENGTH:
raise ValueError(
f"model_type too long ({len(model_type)} chars, max {_MAX_MODEL_TYPE_LENGTH})"
)

cap = max(1, min(limit, search_settings.max_search_results))

audit.log(
Expand Down
22 changes: 22 additions & 0 deletions tests/test_audit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Tests for structured audit logging."""

import json
import os

import pytest

from comfyui_mcp.audit import AuditLogger, AuditRecord

Expand Down Expand Up @@ -81,3 +84,22 @@ def test_log_strips_sensitive_keys_from_extra(self, tmp_path):
content = log_file.read_text()
assert "secret-value" not in content
assert "a cat" in content

def test_log_raises_on_write_failure(self, tmp_path):
log_file = tmp_path / "readonly" / "audit.log"
(tmp_path / "readonly").mkdir()
(tmp_path / "readonly").chmod(0o444)
logger = AuditLogger(audit_file=log_file)
with pytest.raises(OSError):
logger.log(tool="test", action="called")
# Restore permissions for cleanup
(tmp_path / "readonly").chmod(0o755)
Comment on lines +90 to +96
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The permission restoration for the readonly directory only runs if the expected exception is raised. To avoid leaving tmp_path/readonly non-traversable (which can break pytest cleanup) when this test fails unexpectedly on some environments, wrap the chmod change in a try/finally to always restore permissions.

Suggested change
(tmp_path / "readonly").mkdir()
(tmp_path / "readonly").chmod(0o444)
logger = AuditLogger(audit_file=log_file)
with pytest.raises(OSError):
logger.log(tool="test", action="called")
# Restore permissions for cleanup
(tmp_path / "readonly").chmod(0o755)
readonly_dir = tmp_path / "readonly"
readonly_dir.mkdir()
try:
readonly_dir.chmod(0o444)
logger = AuditLogger(audit_file=log_file)
with pytest.raises(OSError):
logger.log(tool="test", action="called")
finally:
# Restore permissions for cleanup
readonly_dir.chmod(0o755)

Copilot uses AI. Check for mistakes.

def test_log_rejects_symlinked_path(self, tmp_path):
real_file = tmp_path / "real.log"
real_file.touch()
link = tmp_path / "audit.log"
os.symlink(real_file, link)
logger = AuditLogger(audit_file=link)
with pytest.raises(OSError, match="symlink"):
logger.log(tool="test", action="called")
Comment on lines +99 to +105
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new symlink protection is only tested for a symlink pointing to an existing target. Consider adding coverage for a dangling symlink (target missing) and for replacing the audit log with a symlink after an initial successful write, to prevent regressions in the symlink-hardening logic.

Copilot generated this review using guidance from repository custom instructions.
18 changes: 18 additions & 0 deletions tests/test_tools_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@ async def test_search_invalid_source(self, registered_tools):
with pytest.raises(ValueError, match="source must be"):
await registered_tools["search_models"](query="test", source="invalid")

async def test_search_empty_query_rejected(self, registered_tools):
with pytest.raises(ValueError, match="must not be empty"):
await registered_tools["search_models"](query="", source="civitai")

async def test_search_whitespace_query_rejected(self, registered_tools):
with pytest.raises(ValueError, match="must not be empty"):
await registered_tools["search_models"](query=" ", source="civitai")

async def test_search_query_too_long_rejected(self, registered_tools):
with pytest.raises(ValueError, match="query too long"):
await registered_tools["search_models"](query="x" * 201, source="civitai")

async def test_search_model_type_too_long_rejected(self, registered_tools):
with pytest.raises(ValueError, match="model_type too long"):
await registered_tools["search_models"](
query="test", source="civitai", model_type="x" * 51
)

@respx.mock
async def test_search_with_api_key(self, components):
components["search_settings"] = ModelSearchSettings(civitai_api_key="test_key")
Expand Down
Loading