Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ docker run --rm ghcr.io/hybridindie/comfyui-mcp:latest --help
| `download_model` | Download a model via [ComfyUI-Model-Manager](https://github.com/hayden-fr/ComfyUI-Model-Manager). URL and extension validated. |
| `get_download_tasks` | Check status of active model downloads (progress, speed, status). |
| `cancel_download` | Cancel or clean up a model download task. |
| `get_model_presets` | Return recommended sampler/scheduler/steps/CFG defaults for a model family. |
| `get_prompting_guide` | Return model-family prompt engineering tips and negative prompt guidance. |

> **Requires:** [ComfyUI-Model-Manager](https://github.com/hayden-fr/ComfyUI-Model-Manager) installed in your ComfyUI instance. Download tools are gated behind lazy detection — if Model Manager is not installed, these tools return a helpful error message. `search_models` works without it.

Expand Down
198 changes: 198 additions & 0 deletions src/comfyui_mcp/tools/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,136 @@
from comfyui_mcp.security.rate_limit import RateLimiter
from comfyui_mcp.security.sanitizer import PathSanitizer

_SUPPORTED_MODEL_FAMILIES = {"sd15", "sdxl", "flux", "sd3", "cascade"}

_MODEL_FAMILY_ALIASES = {
"sd1.5": "sd15",
"sd 1.5": "sd15",
"stable-diffusion-1.5": "sd15",
"stable diffusion 1.5": "sd15",
"stable_diffusion_1_5": "sd15",
"stable-diffusion-xl": "sdxl",
"stable diffusion xl": "sdxl",
"stable_diffusion_xl": "sdxl",
"flux.1": "flux",
"sd3.5": "sd3",
"stable-cascade": "cascade",
}

_MODEL_PRESETS: dict[str, dict[str, Any]] = {
"sd15": {
"recommended": {
"sampler": "euler_ancestral",
"scheduler": "normal",
"steps": 28,
"cfg": 7.0,
"resolution": "512x768",
"clip_skip": 1,
"notes": "Tag-heavy prompts and negative prompts work well.",
}
},
"sdxl": {
"recommended": {
"sampler": "dpmpp_2m",
"scheduler": "karras",
"steps": 30,
"cfg": 5.5,
"resolution": "1024x1024",
"clip_skip": 1,
"notes": "Prefer natural language prompts with clear scene composition.",
}
},
"flux": {
"recommended": {
"sampler": "euler",
"scheduler": "simple",
"steps": 20,
"cfg": 1.0,
"resolution": "1024x1024",
"clip_skip": 1,
"notes": "Flow-matching models expect low CFG and concise language.",
}
},
"sd3": {
"recommended": {
"sampler": "dpmpp_2m",
"scheduler": "sgm_uniform",
"steps": 28,
"cfg": 4.5,
"resolution": "1024x1024",
"clip_skip": 1,
"notes": "Use detailed, descriptive prompts; avoid over-weighting terms.",
}
},
"cascade": {
"recommended": {
"sampler": "dpmpp_2m",
"scheduler": "simple",
"steps": 24,
"cfg": 4.0,
"resolution": "1024x1024",
"clip_skip": 1,
"notes": "Cascade benefits from broad composition instructions first.",
}
},
}

_PROMPTING_GUIDES: dict[str, dict[str, Any]] = {
"sd15": {
"prompt_structure": "subject, style, lighting, lens/composition, quality tags",
"weight_syntax": "(token:1.2)",
"quality_tags": ["masterpiece", "best quality", "high detail"],
"negative_prompt_tips": "Use negatives for anatomy artifacts and low-quality tokens.",
},
"sdxl": {
"prompt_structure": "subject + environment + mood + camera framing",
"weight_syntax": "(token:1.1)",
"quality_tags": ["cinematic lighting", "high detail", "sharp focus"],
"negative_prompt_tips": "Keep negatives shorter than SD1.5 to avoid over-constraining.",
},
"flux": {
"prompt_structure": "natural language sentence describing subject, setting, and style",
"weight_syntax": "Avoid heavy weighting unless necessary",
"quality_tags": ["natural lighting", "detailed texture"],
"negative_prompt_tips": "Use short negatives only for hard constraints (e.g. watermark).",
},
"sd3": {
"prompt_structure": "clear scene description with explicit style and camera intent",
"weight_syntax": "Light weighting only; rely on plain language first",
"quality_tags": ["balanced composition", "fine detail"],
"negative_prompt_tips": (
"Use focused negatives for specific defects, not long keyword lists."
),
},
"cascade": {
"prompt_structure": "high-level composition first, then style modifiers",
"weight_syntax": "(token:1.1) for minor emphasis",
"quality_tags": ["clean composition", "color harmony"],
"negative_prompt_tips": "Keep negatives concise; tune guidance before adding many tokens.",
},
}


def _normalize_model_family(model_family: str) -> str:
key = model_family.strip().lower()
return _MODEL_FAMILY_ALIASES.get(key, key)


def _infer_model_family(model_name: str) -> str | None:
name = model_name.strip().lower()
checks = [
("flux", "flux"),
("sdxl", "sdxl"),
("sd3", "sd3"),
("cascade", "cascade"),
("dreamshaper", "sd15"),
("anything", "sd15"),
]
for needle, family in checks:
if needle in name:
return family
return None


def register_discovery_tools(
mcp: FastMCP,
Expand Down Expand Up @@ -208,4 +338,72 @@ async def get_system_info() -> dict:

tool_fns["get_system_info"] = get_system_info

@mcp.tool()
async def get_model_presets(
model_name: str | None = None,
model_family: str | None = None,
) -> dict[str, Any]:
"""Return recommended generation presets for a model family.

Args:
model_name: Optional model filename to infer family from.
model_family: Optional explicit family (sd15, sdxl, flux, sd3, cascade).

Returns:
Dictionary containing normalized family and recommended settings.
"""
limiter.check("get_model_presets")
audit.log(
tool="get_model_presets",
action="called",
extra={"model_name": model_name, "model_family": model_family},
)

family: str | None = None
if model_family:
family = _normalize_model_family(model_family)
elif model_name:
family = _infer_model_family(model_name)
if family is None:
raise ValueError(f"Could not infer model family from: {model_name}")
else:
raise ValueError("Provide either model_name or model_family")

if family not in _SUPPORTED_MODEL_FAMILIES:
supported = ", ".join(sorted(_SUPPORTED_MODEL_FAMILIES))
raise ValueError(f"Unknown model family: {family}. Supported families: {supported}")

return {
"family": family,
**_MODEL_PRESETS[family],
}

tool_fns["get_model_presets"] = get_model_presets

@mcp.tool()
async def get_prompting_guide(model_family: str) -> dict[str, Any]:
"""Return prompt-engineering guidance for a model family.

Args:
model_family: Family name (sd15, sdxl, flux, sd3, cascade).
"""
limiter.check("get_prompting_guide")
normalized = _normalize_model_family(model_family)
audit.log(
tool="get_prompting_guide",
action="called",
extra={"model_family": normalized},
)

if normalized not in _SUPPORTED_MODEL_FAMILIES:
supported = ", ".join(sorted(_SUPPORTED_MODEL_FAMILIES))
raise ValueError(f"Unknown model family: {normalized}. Supported families: {supported}")

return {
"family": normalized,
"guide": _PROMPTING_GUIDES[normalized],
}

tool_fns["get_prompting_guide"] = get_prompting_guide

return tool_fns
57 changes: 57 additions & 0 deletions tests/test_tools_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,60 @@ async def test_rate_limit_enforced(self, tmp_path):

with pytest.raises(RateLimitError):
await tools["get_system_info"]()


class TestModelPresetsAndGuides:
async def test_get_model_presets_by_family(self, components):
client, audit, limiter, sanitizer = components
mcp = FastMCP("test")
tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer)

result = await tools["get_model_presets"](model_family="flux")

assert result["family"] == "flux"
assert result["recommended"]["sampler"] == "euler"
assert result["recommended"]["cfg"] == 1.0

async def test_get_model_presets_infers_family_from_model_name(self, components):
client, audit, limiter, sanitizer = components
mcp = FastMCP("test")
tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer)

result = await tools["get_model_presets"](model_name="flux1-dev-fp8.safetensors")

assert result["family"] == "flux"

async def test_get_model_presets_rejects_missing_inputs(self, components):
client, audit, limiter, sanitizer = components
mcp = FastMCP("test")
tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer)

with pytest.raises(ValueError, match="Provide either model_name or model_family"):
await tools["get_model_presets"]()

async def test_get_prompting_guide_returns_data(self, components):
client, audit, limiter, sanitizer = components
mcp = FastMCP("test")
tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer)

result = await tools["get_prompting_guide"]("sdxl")

assert result["family"] == "sdxl"
assert "prompt_structure" in result["guide"]
assert "negative_prompt_tips" in result["guide"]

async def test_get_prompting_guide_rejects_unknown_family(self, components):
client, audit, limiter, sanitizer = components
mcp = FastMCP("test")
tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer)

with pytest.raises(ValueError, match="Unknown model family"):
await tools["get_prompting_guide"]("unknown")

async def test_get_model_presets_rejects_unrecognized_model_name(self, components):
client, audit, limiter, sanitizer = components
mcp = FastMCP("test")
tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer)

with pytest.raises(ValueError, match="Could not infer model family from"):
await tools["get_model_presets"](model_name="mystery_model_v1.safetensors")