diff --git a/README.md b/README.md index e8db1a3..5e34aaf 100644 --- a/README.md +++ b/README.md @@ -177,6 +177,8 @@ docker run --rm ghcr.io/hybridindie/comfyui-mcp:latest --help | `download_model` | Download a model via [ComfyUI-Model-Manager](https://github.com/hayden-fr/ComfyUI-Model-Manager). URL and extension validated. | | `get_download_tasks` | Check status of active model downloads (progress, speed, status). | | `cancel_download` | Cancel or clean up a model download task. | +| `get_model_presets` | Return recommended sampler/scheduler/steps/CFG defaults for a model family. | +| `get_prompting_guide` | Return model-family prompt engineering tips and negative prompt guidance. | > **Requires:** [ComfyUI-Model-Manager](https://github.com/hayden-fr/ComfyUI-Model-Manager) installed in your ComfyUI instance. Download tools are gated behind lazy detection — if Model Manager is not installed, these tools return a helpful error message. `search_models` works without it. diff --git a/src/comfyui_mcp/tools/discovery.py b/src/comfyui_mcp/tools/discovery.py index da0b25a..e339d56 100644 --- a/src/comfyui_mcp/tools/discovery.py +++ b/src/comfyui_mcp/tools/discovery.py @@ -12,6 +12,136 @@ from comfyui_mcp.security.rate_limit import RateLimiter from comfyui_mcp.security.sanitizer import PathSanitizer +_SUPPORTED_MODEL_FAMILIES = {"sd15", "sdxl", "flux", "sd3", "cascade"} + +_MODEL_FAMILY_ALIASES = { + "sd1.5": "sd15", + "sd 1.5": "sd15", + "stable-diffusion-1.5": "sd15", + "stable diffusion 1.5": "sd15", + "stable_diffusion_1_5": "sd15", + "stable-diffusion-xl": "sdxl", + "stable diffusion xl": "sdxl", + "stable_diffusion_xl": "sdxl", + "flux.1": "flux", + "sd3.5": "sd3", + "stable-cascade": "cascade", +} + +_MODEL_PRESETS: dict[str, dict[str, Any]] = { + "sd15": { + "recommended": { + "sampler": "euler_ancestral", + "scheduler": "normal", + "steps": 28, + "cfg": 7.0, + "resolution": "512x768", + "clip_skip": 1, + "notes": "Tag-heavy prompts and negative prompts work well.", + } + }, + "sdxl": { + "recommended": { + "sampler": "dpmpp_2m", + "scheduler": "karras", + "steps": 30, + "cfg": 5.5, + "resolution": "1024x1024", + "clip_skip": 1, + "notes": "Prefer natural language prompts with clear scene composition.", + } + }, + "flux": { + "recommended": { + "sampler": "euler", + "scheduler": "simple", + "steps": 20, + "cfg": 1.0, + "resolution": "1024x1024", + "clip_skip": 1, + "notes": "Flow-matching models expect low CFG and concise language.", + } + }, + "sd3": { + "recommended": { + "sampler": "dpmpp_2m", + "scheduler": "sgm_uniform", + "steps": 28, + "cfg": 4.5, + "resolution": "1024x1024", + "clip_skip": 1, + "notes": "Use detailed, descriptive prompts; avoid over-weighting terms.", + } + }, + "cascade": { + "recommended": { + "sampler": "dpmpp_2m", + "scheduler": "simple", + "steps": 24, + "cfg": 4.0, + "resolution": "1024x1024", + "clip_skip": 1, + "notes": "Cascade benefits from broad composition instructions first.", + } + }, +} + +_PROMPTING_GUIDES: dict[str, dict[str, Any]] = { + "sd15": { + "prompt_structure": "subject, style, lighting, lens/composition, quality tags", + "weight_syntax": "(token:1.2)", + "quality_tags": ["masterpiece", "best quality", "high detail"], + "negative_prompt_tips": "Use negatives for anatomy artifacts and low-quality tokens.", + }, + "sdxl": { + "prompt_structure": "subject + environment + mood + camera framing", + "weight_syntax": "(token:1.1)", + "quality_tags": ["cinematic lighting", "high detail", "sharp focus"], + "negative_prompt_tips": "Keep negatives shorter than SD1.5 to avoid over-constraining.", + }, + "flux": { + "prompt_structure": "natural language sentence describing subject, setting, and style", + "weight_syntax": "Avoid heavy weighting unless necessary", + "quality_tags": ["natural lighting", "detailed texture"], + "negative_prompt_tips": "Use short negatives only for hard constraints (e.g. watermark).", + }, + "sd3": { + "prompt_structure": "clear scene description with explicit style and camera intent", + "weight_syntax": "Light weighting only; rely on plain language first", + "quality_tags": ["balanced composition", "fine detail"], + "negative_prompt_tips": ( + "Use focused negatives for specific defects, not long keyword lists." + ), + }, + "cascade": { + "prompt_structure": "high-level composition first, then style modifiers", + "weight_syntax": "(token:1.1) for minor emphasis", + "quality_tags": ["clean composition", "color harmony"], + "negative_prompt_tips": "Keep negatives concise; tune guidance before adding many tokens.", + }, +} + + +def _normalize_model_family(model_family: str) -> str: + key = model_family.strip().lower() + return _MODEL_FAMILY_ALIASES.get(key, key) + + +def _infer_model_family(model_name: str) -> str | None: + name = model_name.strip().lower() + checks = [ + ("flux", "flux"), + ("sdxl", "sdxl"), + ("sd3", "sd3"), + ("cascade", "cascade"), + ("dreamshaper", "sd15"), + ("anything", "sd15"), + ] + for needle, family in checks: + if needle in name: + return family + return None + def register_discovery_tools( mcp: FastMCP, @@ -208,4 +338,72 @@ async def get_system_info() -> dict: tool_fns["get_system_info"] = get_system_info + @mcp.tool() + async def get_model_presets( + model_name: str | None = None, + model_family: str | None = None, + ) -> dict[str, Any]: + """Return recommended generation presets for a model family. + + Args: + model_name: Optional model filename to infer family from. + model_family: Optional explicit family (sd15, sdxl, flux, sd3, cascade). + + Returns: + Dictionary containing normalized family and recommended settings. + """ + limiter.check("get_model_presets") + audit.log( + tool="get_model_presets", + action="called", + extra={"model_name": model_name, "model_family": model_family}, + ) + + family: str | None = None + if model_family: + family = _normalize_model_family(model_family) + elif model_name: + family = _infer_model_family(model_name) + if family is None: + raise ValueError(f"Could not infer model family from: {model_name}") + else: + raise ValueError("Provide either model_name or model_family") + + if family not in _SUPPORTED_MODEL_FAMILIES: + supported = ", ".join(sorted(_SUPPORTED_MODEL_FAMILIES)) + raise ValueError(f"Unknown model family: {family}. Supported families: {supported}") + + return { + "family": family, + **_MODEL_PRESETS[family], + } + + tool_fns["get_model_presets"] = get_model_presets + + @mcp.tool() + async def get_prompting_guide(model_family: str) -> dict[str, Any]: + """Return prompt-engineering guidance for a model family. + + Args: + model_family: Family name (sd15, sdxl, flux, sd3, cascade). + """ + limiter.check("get_prompting_guide") + normalized = _normalize_model_family(model_family) + audit.log( + tool="get_prompting_guide", + action="called", + extra={"model_family": normalized}, + ) + + if normalized not in _SUPPORTED_MODEL_FAMILIES: + supported = ", ".join(sorted(_SUPPORTED_MODEL_FAMILIES)) + raise ValueError(f"Unknown model family: {normalized}. Supported families: {supported}") + + return { + "family": normalized, + "guide": _PROMPTING_GUIDES[normalized], + } + + tool_fns["get_prompting_guide"] = get_prompting_guide + return tool_fns diff --git a/tests/test_tools_discovery.py b/tests/test_tools_discovery.py index 3faaaf4..88eefab 100644 --- a/tests/test_tools_discovery.py +++ b/tests/test_tools_discovery.py @@ -304,3 +304,60 @@ async def test_rate_limit_enforced(self, tmp_path): with pytest.raises(RateLimitError): await tools["get_system_info"]() + + +class TestModelPresetsAndGuides: + async def test_get_model_presets_by_family(self, components): + client, audit, limiter, sanitizer = components + mcp = FastMCP("test") + tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer) + + result = await tools["get_model_presets"](model_family="flux") + + assert result["family"] == "flux" + assert result["recommended"]["sampler"] == "euler" + assert result["recommended"]["cfg"] == 1.0 + + async def test_get_model_presets_infers_family_from_model_name(self, components): + client, audit, limiter, sanitizer = components + mcp = FastMCP("test") + tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer) + + result = await tools["get_model_presets"](model_name="flux1-dev-fp8.safetensors") + + assert result["family"] == "flux" + + async def test_get_model_presets_rejects_missing_inputs(self, components): + client, audit, limiter, sanitizer = components + mcp = FastMCP("test") + tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer) + + with pytest.raises(ValueError, match="Provide either model_name or model_family"): + await tools["get_model_presets"]() + + async def test_get_prompting_guide_returns_data(self, components): + client, audit, limiter, sanitizer = components + mcp = FastMCP("test") + tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer) + + result = await tools["get_prompting_guide"]("sdxl") + + assert result["family"] == "sdxl" + assert "prompt_structure" in result["guide"] + assert "negative_prompt_tips" in result["guide"] + + async def test_get_prompting_guide_rejects_unknown_family(self, components): + client, audit, limiter, sanitizer = components + mcp = FastMCP("test") + tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer) + + with pytest.raises(ValueError, match="Unknown model family"): + await tools["get_prompting_guide"]("unknown") + + async def test_get_model_presets_rejects_unrecognized_model_name(self, components): + client, audit, limiter, sanitizer = components + mcp = FastMCP("test") + tools = register_discovery_tools(mcp, client, audit, limiter, sanitizer) + + with pytest.raises(ValueError, match="Could not infer model family from"): + await tools["get_model_presets"](model_name="mystery_model_v1.safetensors")