diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index d75b29e606fe..e6c834fb2e26 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -3,64 +3,79 @@ import comfy.model_management import nodes import torch -import comfy_extras.nodes_slg +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_extras.nodes_slg import SkipLayerGuidanceDiT -class TripleCLIPLoader: +class TripleCLIPLoader(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name3": (folder_paths.get_filename_list("text_encoders"), ) - }} - RETURN_TYPES = ("CLIP",) - FUNCTION = "load_clip" + def define_schema(cls): + return io.Schema( + node_id="TripleCLIPLoader", + category="advanced/loaders", + description="[Recipes]\n\nsd3: clip-l, clip-g, t5", + inputs=[ + io.Combo.Input("clip_name1", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name2", options=folder_paths.get_filename_list("text_encoders")), + io.Combo.Input("clip_name3", options=folder_paths.get_filename_list("text_encoders")), + ], + outputs=[ + io.Clip.Output(), + ], + ) - CATEGORY = "advanced/loaders" - - DESCRIPTION = "[Recipes]\n\nsd3: clip-l, clip-g, t5" - - def load_clip(self, clip_name1, clip_name2, clip_name3): + @classmethod + def execute(cls, clip_name1, clip_name2, clip_name3) -> io.NodeOutput: clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1) clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2) clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings")) - return (clip,) - + return io.NodeOutput(clip) -class EmptySD3LatentImage: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptySD3LatentImage(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" + def define_schema(cls): + return io.Schema( + node_id="EmptySD3LatentImage", + category="latent/sd3", + inputs=[ + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/sd3" + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples":latent}) - def generate(self, width, height, batch_size=1): - latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device) - return ({"samples":latent}, ) +class CLIPTextEncodeSD3(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeSD3", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("clip_g", multiline=True, dynamic_prompts=True), + io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), + io.Combo.Input("empty_padding", options=["none", "empty_prompt"]), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) -class CLIPTextEncodeSD3: @classmethod - def INPUT_TYPES(s): - return {"required": { - "clip": ("CLIP", ), - "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}), - "empty_padding": (["none", "empty_prompt"], ) - }} - RETURN_TYPES = ("CONDITIONING",) - FUNCTION = "encode" - - CATEGORY = "advanced/conditioning" - - def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding): + def execute(cls, clip, clip_l, clip_g, t5xxl, empty_padding) -> io.NodeOutput: no_padding = empty_padding == "none" tokens = clip.tokenize(clip_g) @@ -82,57 +97,106 @@ def encode(self, clip, clip_l, clip_g, t5xxl, empty_padding): tokens["l"] += empty["l"] while len(tokens["l"]) > len(tokens["g"]): tokens["g"] += empty["g"] - return (clip.encode_from_tokens_scheduled(tokens), ) + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) -class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): +class ControlNetApplySD3(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ControlNetApplySD3", + display_name="Apply Controlnet with VAE", + category="conditioning/controlnet", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.ControlNet.Input("control_net"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + is_deprecated=True, + ) + @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "control_net": ("CONTROL_NET", ), - "vae": ("VAE", ), - "image": ("IMAGE", ), - "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) - }} - CATEGORY = "conditioning/controlnet" - DEPRECATED = True - - -class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT): + def execute(cls, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None) -> io.NodeOutput: + if strength == 0: + return io.NodeOutput(positive, negative) + + control_hint = image.movedim(-1, 1) + cnets = {} + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + + prev_cnet = d.get('control', None) + if prev_cnet in cnets: + c_net = cnets[prev_cnet] + else: + c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), + vae=vae, extra_concat=[]) + c_net.set_previous_controlnet(prev_cnet) + cnets[prev_cnet] = c_net + + d['control'] = c_net + d['control_apply_to_uncond'] = False + n = [t[0], d] + c.append(n) + out.append(c) + return io.NodeOutput(out[0], out[1]) + + +class SkipLayerGuidanceSD3(io.ComfyNode): ''' Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Experimental implementation by Dango233@StabilityAI. ''' + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SkipLayerGuidanceSD3", + category="advanced/guidance", + description="Generic version of SkipLayerGuidance node that can be used on every DiT model.", + inputs=[ + io.Model.Input("model"), + io.String.Input("layers", default="7, 8, 9", multiline=False), + io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1), + io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Model.Output(), + ], + is_experimental=True, + ) + @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), - "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}) - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "skip_guidance_sd3" - - CATEGORY = "advanced/guidance" - - def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent): - return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers) - - -NODE_CLASS_MAPPINGS = { - "TripleCLIPLoader": TripleCLIPLoader, - "EmptySD3LatentImage": EmptySD3LatentImage, - "CLIPTextEncodeSD3": CLIPTextEncodeSD3, - "ControlNetApplySD3": ControlNetApplySD3, - "SkipLayerGuidanceSD3": SkipLayerGuidanceSD3, -} - -NODE_DISPLAY_NAME_MAPPINGS = { - # Sampling - "ControlNetApplySD3": "Apply Controlnet with VAE", -} + def execute(cls, model, layers, scale, start_percent, end_percent) -> io.NodeOutput: + return SkipLayerGuidanceDiT().execute(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers) + + +class SD3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TripleCLIPLoader, + EmptySD3LatentImage, + CLIPTextEncodeSD3, + ControlNetApplySD3, + SkipLayerGuidanceSD3, + ] + + +async def comfy_entrypoint() -> SD3Extension: + return SD3Extension() diff --git a/comfy_extras/nodes_slg.py b/comfy_extras/nodes_slg.py index 7adff202eb2f..c3be80d07fc7 100644 --- a/comfy_extras/nodes_slg.py +++ b/comfy_extras/nodes_slg.py @@ -1,33 +1,40 @@ import comfy.model_patcher import comfy.samplers import re +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class SkipLayerGuidanceDiT: +class SkipLayerGuidanceDiT(io.ComfyNode): ''' Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers. Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377) Original experimental implementation for SD3 by Dango233@StabilityAI. ''' + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SkipLayerGuidanceDiT", + category="advanced/guidance", + description="Generic version of SkipLayerGuidance node that can be used on every DiT model.", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.String.Input("double_layers", default="7, 8, 9"), + io.String.Input("single_layers", default="7, 8, 9"), + io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1), + io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001), + io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) + @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}), - "start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), - "rescaling_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "skip_guidance" - EXPERIMENTAL = True - - DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model." - - CATEGORY = "advanced/guidance" - - def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0): + def execute(cls, model, scale, start_percent, end_percent, double_layers="", single_layers="", rescaling_scale=0) -> io.NodeOutput: # check if layer is comma separated integers def skip(args, extra_args): return args @@ -43,7 +50,7 @@ def skip(args, extra_args): single_layers = [int(i) for i in single_layers] if len(double_layers) == 0 and len(single_layers) == 0: - return (model, ) + return io.NodeOutput(model) def post_cfg_function(args): model = args["model"] @@ -76,29 +83,33 @@ def post_cfg_function(args): m = model.clone() m.set_model_sampler_post_cfg_function(post_cfg_function) - return (m, ) + return io.NodeOutput(m) -class SkipLayerGuidanceDiTSimple: +class SkipLayerGuidanceDiTSimple(io.ComfyNode): ''' Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass. ''' @classmethod - def INPUT_TYPES(s): - return {"required": {"model": ("MODEL", ), - "double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "skip_guidance" - EXPERIMENTAL = True - - DESCRIPTION = "Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass." - - CATEGORY = "advanced/guidance" - - def skip_guidance(self, model, start_percent, end_percent, double_layers="", single_layers=""): + def define_schema(cls): + return io.Schema( + node_id="SkipLayerGuidanceDiTSimple", + category="advanced/guidance", + description="Simple version of the SkipLayerGuidanceDiT node that only modifies the uncond pass.", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.String.Input("double_layers", default="7, 8, 9"), + io.String.Input("single_layers", default="7, 8, 9"), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, start_percent, end_percent, double_layers="", single_layers="") -> io.NodeOutput: def skip(args, extra_args): return args @@ -113,7 +124,7 @@ def skip(args, extra_args): single_layers = [int(i) for i in single_layers] if len(double_layers) == 0 and len(single_layers) == 0: - return (model, ) + return io.NodeOutput(model) def calc_cond_batch_function(args): x = args["input"] @@ -144,9 +155,17 @@ def calc_cond_batch_function(args): m = model.clone() m.set_model_sampler_calc_cond_batch_function(calc_cond_batch_function) - return (m, ) + return io.NodeOutput(m) + + +class SkipLayerGuidanceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SkipLayerGuidanceDiT, + SkipLayerGuidanceDiTSimple, + ] + -NODE_CLASS_MAPPINGS = { - "SkipLayerGuidanceDiT": SkipLayerGuidanceDiT, - "SkipLayerGuidanceDiTSimple": SkipLayerGuidanceDiTSimple, -} +async def comfy_entrypoint() -> SkipLayerGuidanceExtension: + return SkipLayerGuidanceExtension()