diff --git a/README.md b/README.md index 5e34aaf..3c00f0f 100644 --- a/README.md +++ b/README.md @@ -131,7 +131,7 @@ docker run --rm ghcr.io/hybridindie/comfyui-mcp:latest --help | `upscale_image` | Upscale an image using a model-based upscaler. Params: image (filename), upscale_model (default: RealESRGAN_x4plus.pth). | | `run_workflow` | Submit arbitrary ComfyUI workflow JSON. Inspected for dangerous nodes before execution. Set `wait=True` to block until complete and return outputs. | | `summarize_workflow` | Summarize a workflow's structure, data flow, models, and parameters. Supports `format="text"` (default) or `format="mermaid"` for diagram markup. | -| `create_workflow` | Create a workflow from a template (txt2img, img2img, upscale, inpaint, txt2vid_animatediff, txt2vid_wan) with parameter overrides. | +| `create_workflow` | Create a workflow from templates including txt2img/img2img/upscale/inpaint, txt2vid_animatediff/txt2vid_wan, controlnet_canny/controlnet_depth/controlnet_openpose, ip_adapter, lora_stack, face_restore, flux_txt2img, and sdxl_txt2img. | | `modify_workflow` | Apply batch operations (add_node, remove_node, set_input, connect, disconnect) to a workflow. | | `validate_workflow` | Validate workflow structure, server compatibility, and security. | diff --git a/src/comfyui_mcp/server.py b/src/comfyui_mcp/server.py index a717953..b854a4d 100644 --- a/src/comfyui_mcp/server.py +++ b/src/comfyui_mcp/server.py @@ -107,7 +107,7 @@ def _register_all_tools( model_checker=model_checker, sanitizer=sanitizer, ) - register_workflow_tools(server, client, audit, rate_limiters["read"], inspector) + register_workflow_tools(server, client, audit, rate_limiters["read"], inspector, sanitizer) register_model_tools( mcp=server, client=client, diff --git a/src/comfyui_mcp/tools/workflow.py b/src/comfyui_mcp/tools/workflow.py index ee4d79c..e0bb52e 100644 --- a/src/comfyui_mcp/tools/workflow.py +++ b/src/comfyui_mcp/tools/workflow.py @@ -11,10 +11,36 @@ from comfyui_mcp.client import ComfyUIClient from comfyui_mcp.security.inspector import WorkflowInspector from comfyui_mcp.security.rate_limit import RateLimiter +from comfyui_mcp.security.sanitizer import PathSanitizer, PathValidationError from comfyui_mcp.workflow.operations import apply_operations from comfyui_mcp.workflow.templates import create_from_template from comfyui_mcp.workflow.validation import validate_workflow as _validate_workflow +_PATH_LIKE_TEMPLATE_PARAMS = { + "model", + "model_name", + "motion_module", + "controlnet_model", + "ipadapter_model", + "clip_vision_model", + "lora_name", + "face_restore_model", + "image", + "mask", +} + + +def _sanitize_template_params( + param_dict: dict[str, Any], sanitizer: PathSanitizer +) -> dict[str, Any]: + """Sanitize filename-like template params to block traversal/null-byte inputs.""" + sanitized = dict(param_dict) + for key in _PATH_LIKE_TEMPLATE_PARAMS: + value = sanitized.get(key) + if isinstance(value, str): + sanitized[key] = sanitizer.validate_path_segment(value, label=key) + return sanitized + def register_workflow_tools( mcp: FastMCP, @@ -22,6 +48,7 @@ def register_workflow_tools( audit: AuditLogger, limiter: RateLimiter, inspector: WorkflowInspector, + sanitizer: PathSanitizer, ) -> dict[str, Any]: """Register workflow composition tools.""" tool_fns: dict[str, Any] = {} @@ -30,13 +57,16 @@ def register_workflow_tools( async def create_workflow(template: str, params: str = "{}") -> str: """Create a ComfyUI workflow from a template with optional parameter overrides. - Available templates: txt2img, img2img, upscale, inpaint, txt2vid_animatediff, txt2vid_wan. + Available templates: txt2img, img2img, upscale, inpaint, txt2vid_animatediff, + txt2vid_wan, controlnet_canny, controlnet_depth, controlnet_openpose, + ip_adapter, lora_stack, face_restore, flux_txt2img, sdxl_txt2img. Args: template: Template name (e.g. 'txt2img', 'img2img') params: Optional JSON string of parameter overrides. Common params: prompt, negative_prompt, width, height, - steps, cfg, model, denoise. + steps, cfg, model, denoise, controlnet_model, + control_strength, lora_name, lora_strength. """ limiter.check("create_workflow") try: @@ -47,7 +77,12 @@ async def create_workflow(template: str, params: str = "{}") -> str: if not isinstance(param_dict, dict): raise ValueError('params must be a JSON object (e.g. {"key": "value"})') - wf = create_from_template(template, param_dict) + try: + clean_params = _sanitize_template_params(param_dict, sanitizer) + except PathValidationError as e: + raise ValueError(str(e)) from e + + wf = create_from_template(template, clean_params) audit.log( tool="create_workflow", action="created", diff --git a/src/comfyui_mcp/workflow/templates.py b/src/comfyui_mcp/workflow/templates.py index 99f1627..163d042 100644 --- a/src/comfyui_mcp/workflow/templates.py +++ b/src/comfyui_mcp/workflow/templates.py @@ -272,6 +272,358 @@ } +def _build_controlnet_template( + preprocessor_class: str, + preprocessor_inputs: dict[str, Any], + control_net_name: str, + filename_prefix: str, +) -> dict[str, dict[str, Any]]: + """Build a ControlNet template graph with variant-specific preprocessing.""" + return { + "1": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"}, + }, + "2": { + "class_type": "LoadImage", + "inputs": {"image": "control.png"}, + }, + "3": { + "class_type": preprocessor_class, + "inputs": copy.deepcopy(preprocessor_inputs), + }, + "4": { + "class_type": "ControlNetLoader", + "inputs": {"control_net_name": control_net_name}, + }, + "5": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "", "clip": ["1", 1]}, + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "bad quality, blurry", "clip": ["1", 1]}, + }, + "7": { + "class_type": "ControlNetApplyAdvanced", + "inputs": { + "positive": ["5", 0], + "negative": ["6", 0], + "control_net": ["4", 0], + "image": ["3", 0], + "strength": 1.0, + "start_percent": 0.0, + "end_percent": 1.0, + }, + }, + "8": { + "class_type": "EmptyLatentImage", + "inputs": {"width": 512, "height": 512, "batch_size": 1}, + }, + "9": { + "class_type": "KSampler", + "inputs": { + "seed": 0, + "steps": 20, + "cfg": 7.0, + "sampler_name": "euler", + "scheduler": "normal", + "denoise": 1.0, + "model": ["1", 0], + "positive": ["7", 0], + "negative": ["7", 1], + "latent_image": ["8", 0], + }, + }, + "10": { + "class_type": "VAEDecode", + "inputs": {"samples": ["9", 0], "vae": ["1", 2]}, + }, + "11": { + "class_type": "SaveImage", + "inputs": {"filename_prefix": filename_prefix, "images": ["10", 0]}, + }, + } + + +# --- controlnet_canny template --- +_CONTROLNET_CANNY: dict[str, dict[str, Any]] = _build_controlnet_template( + preprocessor_class="CannyEdgePreprocessor", + preprocessor_inputs={ + "image": ["2", 0], + "low_threshold": 100, + "high_threshold": 200, + }, + control_net_name="control_v11p_sd15_canny.safetensors", + filename_prefix="comfyui-mcp-controlnet-canny", +) + +# --- controlnet_depth template --- +_CONTROLNET_DEPTH: dict[str, dict[str, Any]] = _build_controlnet_template( + preprocessor_class="MiDaS-DepthMapPreprocessor", + preprocessor_inputs={ + "image": ["2", 0], + "a": 2.0, + "bg_threshold": 0.1, + }, + control_net_name="control_v11f1p_sd15_depth.safetensors", + filename_prefix="comfyui-mcp-controlnet-depth", +) + +# --- controlnet_openpose template --- +_CONTROLNET_OPENPOSE: dict[str, dict[str, Any]] = _build_controlnet_template( + preprocessor_class="DWPreprocessor", + preprocessor_inputs={ + "image": ["2", 0], + "resolution": 512, + }, + control_net_name="control_v11p_sd15_openpose.safetensors", + filename_prefix="comfyui-mcp-controlnet-openpose", +) + +# --- ip_adapter template --- +_IP_ADAPTER: dict[str, dict[str, Any]] = { + "1": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": "sd_xl_base_1.0.safetensors"}, + }, + "2": { + "class_type": "CLIPVisionLoader", + "inputs": {"clip_name": "CLIP-ViT-H-14-laion2B-s32B-b79K.safetensors"}, + }, + "3": { + "class_type": "IPAdapterModelLoader", + "inputs": {"ipadapter_file": "ip-adapter-plus_sdxl_vit-h.safetensors"}, + }, + "4": { + "class_type": "LoadImage", + "inputs": {"image": "reference.png"}, + }, + "5": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "", "clip": ["1", 1]}, + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "bad quality, blurry", "clip": ["1", 1]}, + }, + "7": { + "class_type": "IPAdapterApply", + "inputs": { + "model": ["1", 0], + "ipadapter": ["3", 0], + "image": ["4", 0], + "clip_vision": ["2", 0], + "weight": 0.75, + }, + }, + "8": { + "class_type": "EmptyLatentImage", + "inputs": {"width": 1024, "height": 1024, "batch_size": 1}, + }, + "9": { + "class_type": "KSampler", + "inputs": { + "seed": 0, + "steps": 28, + "cfg": 6.0, + "sampler_name": "dpmpp_2m", + "scheduler": "karras", + "denoise": 1.0, + "model": ["7", 0], + "positive": ["5", 0], + "negative": ["6", 0], + "latent_image": ["8", 0], + }, + }, + "10": { + "class_type": "VAEDecode", + "inputs": {"samples": ["9", 0], "vae": ["1", 2]}, + }, + "11": { + "class_type": "SaveImage", + "inputs": {"filename_prefix": "comfyui-mcp-ipadapter", "images": ["10", 0]}, + }, +} + +# --- lora_stack template --- +_LORA_STACK: dict[str, dict[str, Any]] = { + "1": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": "v1-5-pruned-emaonly.safetensors"}, + }, + "2": { + "class_type": "LoraLoader", + "inputs": { + "model": ["1", 0], + "clip": ["1", 1], + "lora_name": "detail-tweaker.safetensors", + "strength_model": 0.75, + "strength_clip": 0.75, + }, + }, + "3": { + "class_type": "LoraLoader", + "inputs": { + "model": ["2", 0], + "clip": ["2", 1], + "lora_name": "style-cinematic.safetensors", + "strength_model": 0.45, + "strength_clip": 0.45, + }, + }, + "4": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "", "clip": ["3", 1]}, + }, + "5": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "bad quality, blurry", "clip": ["3", 1]}, + }, + "6": { + "class_type": "EmptyLatentImage", + "inputs": {"width": 768, "height": 768, "batch_size": 1}, + }, + "7": { + "class_type": "KSampler", + "inputs": { + "seed": 0, + "steps": 24, + "cfg": 6.5, + "sampler_name": "dpmpp_2m", + "scheduler": "karras", + "denoise": 1.0, + "model": ["3", 0], + "positive": ["4", 0], + "negative": ["5", 0], + "latent_image": ["6", 0], + }, + }, + "8": { + "class_type": "VAEDecode", + "inputs": {"samples": ["7", 0], "vae": ["1", 2]}, + }, + "9": { + "class_type": "SaveImage", + "inputs": {"filename_prefix": "comfyui-mcp-lora-stack", "images": ["8", 0]}, + }, +} + +# --- face_restore template --- +_FACE_RESTORE: dict[str, dict[str, Any]] = { + "1": { + "class_type": "LoadImage", + "inputs": {"image": "input.png"}, + }, + "2": { + "class_type": "UpscaleModelLoader", + "inputs": {"model_name": "RealESRGAN_x4plus.pth"}, + }, + "3": { + "class_type": "ImageUpscaleWithModel", + "inputs": {"upscale_model": ["2", 0], "image": ["1", 0]}, + }, + "4": { + "class_type": "FaceRestoreModelLoader", + "inputs": {"model_name": "codeformer.pth"}, + }, + "5": { + "class_type": "FaceRestoreCFWithModel", + "inputs": {"facerestore_model": ["4", 0], "image": ["3", 0], "fidelity": 0.7}, + }, + "6": { + "class_type": "SaveImage", + "inputs": {"filename_prefix": "comfyui-mcp-face-restore", "images": ["5", 0]}, + }, +} + +# --- flux_txt2img template --- +_FLUX_TXT2IMG: dict[str, dict[str, Any]] = { + "1": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": "flux1-dev.safetensors"}, + }, + "2": { + "class_type": "EmptyLatentImage", + "inputs": {"width": 1024, "height": 1024, "batch_size": 1}, + }, + "3": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "", "clip": ["1", 1]}, + }, + "4": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "low quality, artifacts", "clip": ["1", 1]}, + }, + "5": { + "class_type": "KSampler", + "inputs": { + "seed": 0, + "steps": 20, + "cfg": 1.0, + "sampler_name": "euler", + "scheduler": "simple", + "denoise": 1.0, + "model": ["1", 0], + "positive": ["3", 0], + "negative": ["4", 0], + "latent_image": ["2", 0], + }, + }, + "6": { + "class_type": "VAEDecode", + "inputs": {"samples": ["5", 0], "vae": ["1", 2]}, + }, + "7": { + "class_type": "SaveImage", + "inputs": {"filename_prefix": "comfyui-mcp-flux", "images": ["6", 0]}, + }, +} + +# --- sdxl_txt2img template --- +_SDXL_TXT2IMG: dict[str, dict[str, Any]] = { + "1": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": "sd_xl_base_1.0.safetensors"}, + }, + "2": { + "class_type": "EmptyLatentImage", + "inputs": {"width": 1024, "height": 1024, "batch_size": 1}, + }, + "3": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "", "clip": ["1", 1]}, + }, + "4": { + "class_type": "CLIPTextEncode", + "inputs": {"text": "bad quality, blurry", "clip": ["1", 1]}, + }, + "5": { + "class_type": "KSampler", + "inputs": { + "seed": 0, + "steps": 30, + "cfg": 6.0, + "sampler_name": "dpmpp_2m", + "scheduler": "karras", + "denoise": 1.0, + "model": ["1", 0], + "positive": ["3", 0], + "negative": ["4", 0], + "latent_image": ["2", 0], + }, + }, + "6": { + "class_type": "VAEDecode", + "inputs": {"samples": ["5", 0], "vae": ["1", 2]}, + }, + "7": { + "class_type": "SaveImage", + "inputs": {"filename_prefix": "comfyui-mcp-sdxl", "images": ["6", 0]}, + }, +} + + # --- Param application --- _PARAM_MAP: dict[str, list[tuple[str, str]]] = { @@ -292,6 +644,18 @@ "image": [("LoadImage", "image")], "mask": [("LoadImageMask", "image")], "fps": [("SaveAnimatedWEBP", "fps")], + "controlnet_model": [("ControlNetLoader", "control_net_name")], + "control_strength": [("ControlNetApplyAdvanced", "strength")], + "ipadapter_model": [("IPAdapterModelLoader", "ipadapter_file")], + "ipadapter_weight": [("IPAdapterApply", "weight")], + "clip_vision_model": [("CLIPVisionLoader", "clip_name")], + "lora_name": [("LoraLoader", "lora_name")], + "lora_strength": [ + ("LoraLoader", "strength_model"), + ("LoraLoader", "strength_clip"), + ], + "face_restore_model": [("FaceRestoreModelLoader", "model_name")], + "face_restore_fidelity": [("FaceRestoreCFWithModel", "fidelity")], } @@ -329,6 +693,14 @@ def _apply_params(wf: dict[str, Any], params: dict[str, Any]) -> None: "inpaint": _INPAINT, "txt2vid_animatediff": _TXT2VID_ANIMATEDIFF, "txt2vid_wan": _TXT2VID_WAN, + "controlnet_canny": _CONTROLNET_CANNY, + "controlnet_depth": _CONTROLNET_DEPTH, + "controlnet_openpose": _CONTROLNET_OPENPOSE, + "ip_adapter": _IP_ADAPTER, + "lora_stack": _LORA_STACK, + "face_restore": _FACE_RESTORE, + "flux_txt2img": _FLUX_TXT2IMG, + "sdxl_txt2img": _SDXL_TXT2IMG, } diff --git a/tests/test_tools_workflow.py b/tests/test_tools_workflow.py index bd0c5ba..945a93c 100644 --- a/tests/test_tools_workflow.py +++ b/tests/test_tools_workflow.py @@ -13,6 +13,7 @@ from comfyui_mcp.client import ComfyUIClient from comfyui_mcp.security.inspector import WorkflowInspector from comfyui_mcp.security.rate_limit import RateLimiter +from comfyui_mcp.security.sanitizer import PathSanitizer from comfyui_mcp.tools.workflow import register_workflow_tools @@ -22,14 +23,32 @@ def components(tmp_path): audit = AuditLogger(audit_file=tmp_path / "audit.log") limiter = RateLimiter(max_per_minute=60) inspector = WorkflowInspector(mode="audit", dangerous_nodes=["EvalNode"], allowed_nodes=[]) - return client, audit, limiter, inspector + sanitizer = PathSanitizer( + allowed_extensions=[ + ".png", + ".jpg", + ".jpeg", + ".webp", + ".gif", + ".json", + ".safetensors", + ".ckpt", + ".pth", + ".pt", + ".onnx", + ".bin", + ".gguf", + ".patch", + ] + ) + return client, audit, limiter, inspector, sanitizer class TestCreateWorkflow: async def test_creates_txt2img(self, components): - client, audit, limiter, inspector = components + client, audit, limiter, inspector, sanitizer = components mcp = FastMCP("test") - tools = register_workflow_tools(mcp, client, audit, limiter, inspector) + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) result = await tools["create_workflow"](template="txt2img") wf = json.loads(result) class_types = {v["class_type"] for v in wf.values()} @@ -37,37 +56,57 @@ async def test_creates_txt2img(self, components): assert "CheckpointLoaderSimple" in class_types async def test_creates_with_params(self, components): - client, audit, limiter, inspector = components + client, audit, limiter, inspector, sanitizer = components mcp = FastMCP("test") - tools = register_workflow_tools(mcp, client, audit, limiter, inspector) + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) params = json.dumps({"prompt": "a dog", "steps": 30}) result = await tools["create_workflow"](template="txt2img", params=params) wf = json.loads(result) sampler = next(v for v in wf.values() if v["class_type"] == "KSampler") assert sampler["inputs"]["steps"] == 30 + async def test_creates_expanded_template(self, components): + client, audit, limiter, inspector, sanitizer = components + mcp = FastMCP("test") + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) + params = json.dumps({"control_strength": 0.8}) + result = await tools["create_workflow"](template="controlnet_canny", params=params) + wf = json.loads(result) + class_types = {v["class_type"] for v in wf.values()} + assert "ControlNetApplyAdvanced" in class_types + control_apply = next(v for v in wf.values() if v["class_type"] == "ControlNetApplyAdvanced") + assert control_apply["inputs"]["strength"] == 0.8 + async def test_invalid_template_raises(self, components): - client, audit, limiter, inspector = components + client, audit, limiter, inspector, sanitizer = components mcp = FastMCP("test") - tools = register_workflow_tools(mcp, client, audit, limiter, inspector) + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) with pytest.raises(ValueError, match="Unknown template"): await tools["create_workflow"](template="nonexistent") async def test_audit_log_written(self, components): - client, audit, limiter, inspector = components + client, audit, limiter, inspector, sanitizer = components mcp = FastMCP("test") - tools = register_workflow_tools(mcp, client, audit, limiter, inspector) + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) await tools["create_workflow"](template="txt2img") log_lines = audit._audit_file.read_text().strip().split("\n") entries = [json.loads(line) for line in log_lines] assert any(e["tool"] == "create_workflow" for e in entries) + async def test_rejects_path_traversal_param(self, components): + client, audit, limiter, inspector, sanitizer = components + mcp = FastMCP("test") + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) + params = json.dumps({"controlnet_model": "../evil.safetensors"}) + with pytest.raises(ValueError, match=r"path separator|path traversal"): + await tools["create_workflow"](template="controlnet_canny", params=params) + class TestModifyWorkflow: async def test_adds_node(self, components): - client, audit, limiter, inspector = components + client, audit, limiter, inspector, sanitizer = components mcp = FastMCP("test") - tools = register_workflow_tools(mcp, client, audit, limiter, inspector) + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) wf = json.dumps({"1": {"class_type": "KSampler", "inputs": {}}}) ops = json.dumps([{"op": "add_node", "class_type": "SaveImage"}]) result = await tools["modify_workflow"](workflow=wf, operations=ops) @@ -75,17 +114,17 @@ async def test_adds_node(self, components): assert "2" in modified async def test_invalid_workflow_json_raises(self, components): - client, audit, limiter, inspector = components + client, audit, limiter, inspector, sanitizer = components mcp = FastMCP("test") - tools = register_workflow_tools(mcp, client, audit, limiter, inspector) + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) ops = json.dumps([{"op": "add_node", "class_type": "SaveImage"}]) with pytest.raises(ValueError, match="Invalid JSON"): await tools["modify_workflow"](workflow="not json", operations=ops) async def test_invalid_operations_json_raises(self, components): - client, audit, limiter, inspector = components + client, audit, limiter, inspector, sanitizer = components mcp = FastMCP("test") - tools = register_workflow_tools(mcp, client, audit, limiter, inspector) + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) wf = json.dumps({"1": {"class_type": "KSampler", "inputs": {}}}) with pytest.raises(ValueError, match="Invalid JSON"): await tools["modify_workflow"](workflow=wf, operations="not json") @@ -94,7 +133,7 @@ async def test_invalid_operations_json_raises(self, components): class TestValidateWorkflow: @respx.mock async def test_valid_workflow(self, components): - client, audit, limiter, inspector = components + client, audit, limiter, inspector, sanitizer = components respx.get("http://test:8188/object_info").mock( return_value=httpx.Response( 200, @@ -104,16 +143,16 @@ async def test_valid_workflow(self, components): ) ) mcp = FastMCP("test") - tools = register_workflow_tools(mcp, client, audit, limiter, inspector) + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) wf = json.dumps({"1": {"class_type": "KSampler", "inputs": {}}}) result = await tools["validate_workflow"](workflow=wf) parsed = json.loads(result) assert parsed["valid"] is True async def test_invalid_json_raises(self, components): - client, audit, limiter, inspector = components + client, audit, limiter, inspector, sanitizer = components mcp = FastMCP("test") - tools = register_workflow_tools(mcp, client, audit, limiter, inspector) + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) with pytest.raises(ValueError, match="Invalid JSON"): await tools["validate_workflow"](workflow="not json") @@ -121,10 +160,10 @@ async def test_invalid_json_raises(self, components): class TestIntegration: @respx.mock async def test_create_modify_validate_roundtrip(self, components): - client, audit, limiter, inspector = components + client, audit, limiter, inspector, sanitizer = components respx.get("http://test:8188/object_info").mock(side_effect=httpx.ConnectError("offline")) mcp = FastMCP("test") - tools = register_workflow_tools(mcp, client, audit, limiter, inspector) + tools = register_workflow_tools(mcp, client, audit, limiter, inspector, sanitizer) # Create created = await tools["create_workflow"]( diff --git a/tests/test_workflow_templates.py b/tests/test_workflow_templates.py index a4703d0..ab36e52 100644 --- a/tests/test_workflow_templates.py +++ b/tests/test_workflow_templates.py @@ -154,5 +154,78 @@ def test_templates_registry_has_all(self): "inpaint", "txt2vid_animatediff", "txt2vid_wan", + "controlnet_canny", + "controlnet_depth", + "controlnet_openpose", + "ip_adapter", + "lora_stack", + "face_restore", + "flux_txt2img", + "sdxl_txt2img", } assert set(TEMPLATES.keys()) == expected + + +class TestExpandedTemplates: + def test_controlnet_canny_has_expected_nodes(self): + wf = create_from_template("controlnet_canny") + class_types = {v["class_type"] for v in wf.values()} + assert "CannyEdgePreprocessor" in class_types + assert "ControlNetLoader" in class_types + assert "ControlNetApplyAdvanced" in class_types + + def test_controlnet_depth_override(self): + wf = create_from_template( + "controlnet_depth", + {"controlnet_model": "custom-depth.safetensors", "control_strength": 0.65}, + ) + loader = _get_nodes_by_type(wf, "ControlNetLoader") + assert loader[0]["inputs"]["control_net_name"] == "custom-depth.safetensors" + apply = _get_nodes_by_type(wf, "ControlNetApplyAdvanced") + assert apply[0]["inputs"]["strength"] == 0.65 + + def test_ip_adapter_weight_override(self): + wf = create_from_template( + "ip_adapter", + { + "ipadapter_model": "ip-adapter_sdxl.safetensors", + "clip_vision_model": "clip-vit-bigg-14.safetensors", + "ipadapter_weight": 0.5, + }, + ) + ip_model = _get_nodes_by_type(wf, "IPAdapterModelLoader") + assert ip_model[0]["inputs"]["ipadapter_file"] == "ip-adapter_sdxl.safetensors" + clip_vision = _get_nodes_by_type(wf, "CLIPVisionLoader") + assert clip_vision[0]["inputs"]["clip_name"] == "clip-vit-bigg-14.safetensors" + apply = _get_nodes_by_type(wf, "IPAdapterApply") + assert apply[0]["inputs"]["weight"] == 0.5 + + def test_lora_stack_overrides(self): + wf = create_from_template( + "lora_stack", + {"lora_name": "my_style.safetensors", "lora_strength": 0.9}, + ) + loras = _get_nodes_by_type(wf, "LoraLoader") + assert loras + for lora in loras: + assert lora["inputs"]["lora_name"] == "my_style.safetensors" + assert lora["inputs"]["strength_model"] == 0.9 + assert lora["inputs"]["strength_clip"] == 0.9 + + def test_face_restore_and_family_specific_templates(self): + face_wf = create_from_template( + "face_restore", + {"model_name": "4x-UltraSharp.pth", "face_restore_model": "gfpgan_v1.4.pth"}, + ) + upscaler = _get_nodes_by_type(face_wf, "UpscaleModelLoader") + assert upscaler[0]["inputs"]["model_name"] == "4x-UltraSharp.pth" + restore = _get_nodes_by_type(face_wf, "FaceRestoreModelLoader") + assert restore[0]["inputs"]["model_name"] == "gfpgan_v1.4.pth" + + flux_wf = create_from_template("flux_txt2img") + flux_sampler = _get_nodes_by_type(flux_wf, "KSampler") + assert flux_sampler[0]["inputs"]["cfg"] == 1.0 + + sdxl_wf = create_from_template("sdxl_txt2img") + sdxl_sampler = _get_nodes_by_type(sdxl_wf, "KSampler") + assert sdxl_sampler[0]["inputs"]["scheduler"] == "karras"