Skip to content

Commit 54b058a

Browse files
authored
Merge pull request #30 from hybridindie/fix/p1-security-performance
fix: address P1 security and performance findings
2 parents a6d62b3 + 48a24a3 commit 54b058a

File tree

15 files changed

+461
-150
lines changed

15 files changed

+461
-150
lines changed

src/comfyui_mcp/audit.py

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@
22

33
from __future__ import annotations
44

5+
import asyncio
6+
import logging
7+
import threading
58
from datetime import UTC, datetime
69
from pathlib import Path
10+
from typing import Any
711

812
from pydantic import BaseModel, Field, model_serializer
913

14+
_logger = logging.getLogger(__name__)
15+
1016
_SENSITIVE_KEYS = {"token", "password", "secret", "api_key", "authorization"}
1117

1218

@@ -51,16 +57,58 @@ def serialize(self) -> dict[str, object]:
5157
class AuditLogger:
5258
def __init__(self, audit_file: Path) -> None:
5359
self._audit_file = Path(audit_file)
60+
self._dir_created = False
61+
self._lock = threading.Lock()
5462

55-
def log(self, *, tool: str, action: str, **kwargs) -> AuditRecord:
56-
"""Write an audit record as a JSON line."""
57-
record = AuditRecord(tool=tool, action=action, **kwargs)
63+
def _is_path_safe(self) -> bool:
64+
"""Check that neither the audit file nor any parent is a symlink.
65+
66+
Uses is_symlink() which detects both live and dangling symlinks
67+
(unlike exists() which returns False for dangling symlinks).
68+
"""
69+
if self._audit_file.is_symlink():
70+
return False
71+
return all(not parent.is_symlink() for parent in self._audit_file.parents)
72+
73+
def _ensure_dir(self) -> bool:
74+
"""Create parent directories once. Returns False on failure."""
75+
if self._dir_created:
76+
return True
5877
try:
5978
self._audit_file.parent.mkdir(parents=True, exist_ok=True)
60-
with open(self._audit_file, "a") as f:
61-
f.write(record.model_dump_json() + "\n")
79+
self._dir_created = True
80+
return True
6281
except OSError as e:
63-
import logging
82+
_logger.error("AUDIT LOG FAILURE: cannot create directory: %s", e)
83+
return False
84+
85+
def _write_record(self, record: AuditRecord) -> None:
86+
"""Synchronous, thread-safe write of a single audit record."""
87+
with self._lock:
88+
# Check symlink safety on every write (not cached) to detect
89+
# post-init symlink swaps on the file or any parent directory
90+
if not self._is_path_safe():
91+
_logger.error(
92+
"AUDIT LOG REFUSED: path contains symlink: %s",
93+
self._audit_file,
94+
)
95+
return
96+
if not self._ensure_dir():
97+
return
98+
try:
99+
with open(self._audit_file, "a") as f:
100+
f.write(record.model_dump_json() + "\n")
101+
except OSError as e:
102+
_logger.error("AUDIT LOG FAILURE: %s", e)
64103

65-
logging.getLogger(__name__).error("AUDIT LOG FAILURE: %s", e)
104+
def log(self, *, tool: str, action: str, **kwargs: Any) -> AuditRecord:
105+
"""Write an audit record as a JSON line (synchronous)."""
106+
record = AuditRecord(tool=tool, action=action, **kwargs)
107+
self._write_record(record)
108+
return record
109+
110+
async def async_log(self, *, tool: str, action: str, **kwargs: Any) -> AuditRecord:
111+
"""Write an audit record without blocking the event loop."""
112+
record = AuditRecord(tool=tool, action=action, **kwargs)
113+
await asyncio.to_thread(self._write_record, record)
66114
return record

src/comfyui_mcp/server.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22

33
from __future__ import annotations
44

5+
import atexit
6+
import contextlib
57
from pathlib import Path
68

9+
import httpx
710
from mcp.server.fastmcp import FastMCP
811

912
from comfyui_mcp.audit import AuditLogger
@@ -83,6 +86,7 @@ def _register_all_tools(
8386
download_validator: DownloadValidator,
8487
model_checker: ModelChecker,
8588
model_search_settings: ModelSearchSettings,
89+
search_http: httpx.AsyncClient,
8690
) -> None:
8791
"""Register all MCP tool groups with their dependencies."""
8892
register_discovery_tools(server, client, audit, rate_limiters["read"], sanitizer, node_auditor)
@@ -118,10 +122,13 @@ def _register_all_tools(
118122
detector=detector,
119123
validator=download_validator,
120124
search_settings=model_search_settings,
125+
search_http=search_http,
121126
)
122127

123128

124-
def _build_server(settings: Settings | None = None) -> tuple[FastMCP, Settings]:
129+
def _build_server(
130+
settings: Settings | None = None,
131+
) -> tuple[FastMCP, Settings, ComfyUIClient, httpx.AsyncClient]:
125132
"""Build and configure the MCP server with all tools registered."""
126133
if settings is None:
127134
settings = load_settings()
@@ -144,6 +151,7 @@ def _build_server(settings: Settings | None = None) -> tuple[FastMCP, Settings]:
144151
allowed_extensions=settings.security.allowed_model_extensions,
145152
)
146153
model_checker = ModelChecker()
154+
search_http = httpx.AsyncClient(timeout=httpx.Timeout(connect=10, read=30, write=10, pool=10))
147155

148156
server_kwargs: dict = {
149157
"name": "ComfyUI",
@@ -183,13 +191,29 @@ def _build_server(settings: Settings | None = None) -> tuple[FastMCP, Settings]:
183191
download_validator=download_validator,
184192
model_checker=model_checker,
185193
model_search_settings=settings.model_search,
194+
search_http=search_http,
186195
)
187196

188-
return server, settings
197+
return server, settings, client, search_http
189198

190199

191200
# Module-level server instance for import and CLI use
192-
mcp, _settings = _build_server()
201+
mcp, _settings, _client, _search_http = _build_server()
202+
203+
204+
def _cleanup() -> None:
205+
"""Best-effort cleanup of HTTP clients on process exit."""
206+
import asyncio
207+
208+
async def _close() -> None:
209+
await _client.close()
210+
await _search_http.aclose()
211+
212+
with contextlib.suppress(Exception):
213+
asyncio.run(_close())
214+
215+
216+
atexit.register(_cleanup)
193217

194218

195219
def main() -> None:

src/comfyui_mcp/tools/discovery.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ async def list_models(folder: str = "checkpoints") -> list[str]:
159159
"""List available models in a folder (checkpoints, loras, vae, etc.)."""
160160
limiter.check("list_models")
161161
sanitizer.validate_path_segment(folder, label="folder")
162-
audit.log(tool="list_models", action="called", extra={"folder": folder})
162+
await audit.async_log(tool="list_models", action="called", extra={"folder": folder})
163163
return await client.get_models(folder)
164164

165165
tool_fns["list_models"] = list_models
@@ -168,7 +168,7 @@ async def list_models(folder: str = "checkpoints") -> list[str]:
168168
async def list_nodes() -> list[str]:
169169
"""List all available ComfyUI node types."""
170170
limiter.check("list_nodes")
171-
audit.log(tool="list_nodes", action="called")
171+
await audit.async_log(tool="list_nodes", action="called")
172172
info = await client.get_object_info()
173173
return sorted(info.keys())
174174

@@ -178,7 +178,9 @@ async def list_nodes() -> list[str]:
178178
async def get_node_info(node_class: str) -> dict:
179179
"""Get detailed information about a specific node type."""
180180
limiter.check("get_node_info")
181-
audit.log(tool="get_node_info", action="called", extra={"node_class": node_class})
181+
await audit.async_log(
182+
tool="get_node_info", action="called", extra={"node_class": node_class}
183+
)
182184
return await client.get_object_info(node_class)
183185

184186
tool_fns["get_node_info"] = get_node_info
@@ -187,7 +189,7 @@ async def get_node_info(node_class: str) -> dict:
187189
async def list_workflows() -> list:
188190
"""List available workflow templates."""
189191
limiter.check("list_workflows")
190-
audit.log(tool="list_workflows", action="called")
192+
await audit.async_log(tool="list_workflows", action="called")
191193
return await client.get_workflow_templates()
192194

193195
tool_fns["list_workflows"] = list_workflows
@@ -196,7 +198,7 @@ async def list_workflows() -> list:
196198
async def list_extensions() -> list:
197199
"""List available ComfyUI extensions."""
198200
limiter.check("list_extensions")
199-
audit.log(tool="list_extensions", action="called")
201+
await audit.async_log(tool="list_extensions", action="called")
200202
return await client.get_extensions()
201203

202204
tool_fns["list_extensions"] = list_extensions
@@ -205,7 +207,7 @@ async def list_extensions() -> list:
205207
async def get_server_features() -> dict:
206208
"""Get ComfyUI server features and capabilities."""
207209
limiter.check("get_server_features")
208-
audit.log(tool="get_server_features", action="called")
210+
await audit.async_log(tool="get_server_features", action="called")
209211
return await client.get_features()
210212

211213
tool_fns["get_server_features"] = get_server_features
@@ -214,7 +216,7 @@ async def get_server_features() -> dict:
214216
async def list_model_folders() -> list[str]:
215217
"""List available model folder types (checkpoints, loras, vae, etc.)."""
216218
limiter.check("list_model_folders")
217-
audit.log(tool="list_model_folders", action="called")
219+
await audit.async_log(tool="list_model_folders", action="called")
218220
return await client.get_model_types()
219221

220222
tool_fns["list_model_folders"] = list_model_folders
@@ -230,7 +232,7 @@ async def get_model_metadata(folder: str, filename: str) -> dict:
230232
limiter.check("get_model_metadata")
231233
sanitizer.validate_path_segment(folder, label="folder")
232234
sanitizer.validate_path_segment(filename, label="filename")
233-
audit.log(
235+
await audit.async_log(
234236
tool="get_model_metadata",
235237
action="called",
236238
extra={"folder": folder, "filename": filename},
@@ -250,7 +252,7 @@ async def audit_dangerous_nodes() -> dict:
250252
Dictionary with dangerous and suspicious node counts and lists
251253
"""
252254
limiter.check("audit_dangerous_nodes")
253-
audit.log(tool="audit_dangerous_nodes", action="started")
255+
await audit.async_log(tool="audit_dangerous_nodes", action="started")
254256

255257
auditor = node_auditor if node_auditor else NodeAuditor()
256258

@@ -273,7 +275,7 @@ async def audit_dangerous_nodes() -> dict:
273275
},
274276
}
275277

276-
audit.log(
278+
await audit.async_log(
277279
tool="audit_dangerous_nodes",
278280
action="completed",
279281
extra={
@@ -300,7 +302,7 @@ async def get_system_info() -> dict:
300302
queue (running/pending counts).
301303
"""
302304
limiter.check("get_system_info")
303-
audit.log(tool="get_system_info", action="called")
305+
await audit.async_log(tool="get_system_info", action="called")
304306

305307
raw = await client.get_system_stats()
306308
queue_raw = await client.get_queue()
@@ -353,7 +355,7 @@ async def get_model_presets(
353355
Dictionary containing normalized family and recommended settings.
354356
"""
355357
limiter.check("get_model_presets")
356-
audit.log(
358+
await audit.async_log(
357359
tool="get_model_presets",
358360
action="called",
359361
extra={"model_name": model_name, "model_family": model_family},
@@ -389,7 +391,7 @@ async def get_prompting_guide(model_family: str) -> dict[str, Any]:
389391
"""
390392
limiter.check("get_prompting_guide")
391393
normalized = _normalize_model_family(model_family)
392-
audit.log(
394+
await audit.async_log(
393395
tool="get_prompting_guide",
394396
action="called",
395397
extra={"model_family": normalized},

src/comfyui_mcp/tools/files.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,13 @@ async def upload_image(filename: str, image_data: str, subfolder: str = "") -> s
103103
clean_subfolder = sanitizer.validate_subfolder(subfolder)
104104
raw = base64.b64decode(image_data)
105105
sanitizer.validate_size(len(raw))
106-
audit.log(
106+
await audit.async_log(
107107
tool="upload_image",
108108
action="uploading",
109109
extra={"filename": clean_name, "size_bytes": len(raw)},
110110
)
111111
result = await client.upload_image(raw, clean_name, clean_subfolder)
112-
audit.log(tool="upload_image", action="uploaded", extra={"result": result})
112+
await audit.async_log(tool="upload_image", action="uploaded", extra={"result": result})
113113
return f"Uploaded {result.get('name', clean_name)} to ComfyUI input directory"
114114

115115
tool_fns["upload_image"] = upload_image
@@ -128,7 +128,9 @@ async def get_image(filename: str, subfolder: str = "output") -> str:
128128
limiter.check("get_image")
129129
clean_name = sanitizer.validate_filename(filename)
130130
clean_subfolder = sanitizer.validate_subfolder(subfolder)
131-
audit.log(tool="get_image", action="downloading", extra={"filename": clean_name})
131+
await audit.async_log(
132+
tool="get_image", action="downloading", extra={"filename": clean_name}
133+
)
132134
data, content_type = await client.get_image(clean_name, clean_subfolder)
133135
b64 = base64.b64encode(data).decode()
134136
return f"data:{content_type};base64,{b64}"
@@ -139,7 +141,7 @@ async def get_image(filename: str, subfolder: str = "output") -> str:
139141
async def list_outputs() -> list[str]:
140142
"""List files in ComfyUI's output directory."""
141143
limiter.check("list_outputs")
142-
audit.log(tool="list_outputs", action="called")
144+
await audit.async_log(tool="list_outputs", action="called")
143145
history = await client.get_history()
144146
filenames = set()
145147
for entry in history.values():
@@ -167,13 +169,13 @@ async def upload_mask(filename: str, mask_data: str, subfolder: str = "") -> str
167169
clean_subfolder = sanitizer.validate_subfolder(subfolder)
168170
raw = base64.b64decode(mask_data)
169171
sanitizer.validate_size(len(raw))
170-
audit.log(
172+
await audit.async_log(
171173
tool="upload_mask",
172174
action="uploading",
173175
extra={"filename": clean_name, "size_bytes": len(raw)},
174176
)
175177
result = await client.upload_mask(raw, clean_name, clean_subfolder)
176-
audit.log(tool="upload_mask", action="uploaded", extra={"result": result})
178+
await audit.async_log(tool="upload_mask", action="uploaded", extra={"result": result})
177179
return f"Uploaded mask {result.get('name', clean_name)} to ComfyUI input directory"
178180

179181
tool_fns["upload_mask"] = upload_mask
@@ -197,7 +199,7 @@ async def get_workflow_from_image(filename: str, subfolder: str = "output") -> d
197199
limiter.check("get_workflow_from_image")
198200
clean_name = sanitizer.validate_filename(filename)
199201
clean_subfolder = sanitizer.validate_subfolder(subfolder)
200-
audit.log(
202+
await audit.async_log(
201203
tool="get_workflow_from_image",
202204
action="extracting",
203205
extra={"filename": clean_name, "subfolder": clean_subfolder},
@@ -235,7 +237,7 @@ async def get_workflow_from_image(filename: str, subfolder: str = "output") -> d
235237
else:
236238
message = "No workflow metadata found in this image"
237239

238-
audit.log(
240+
await audit.async_log(
239241
tool="get_workflow_from_image",
240242
action="extracted",
241243
extra={

0 commit comments

Comments
 (0)