diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 0f90cf60c9de..d2df07ff9654 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -2,6 +2,8 @@ import comfy_extras.nodes_post_processing import torch import nodes +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io def reshape_latent_to(target_shape, latent, repeat_batch=True): @@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True): return latent -class LatentAdd: +class LatentAdd(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} + def define_schema(cls): + return io.Schema( + node_id="LatentAdd", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples1, samples2): + @classmethod + def execute(cls, samples1, samples2) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -31,19 +39,25 @@ def op(self, samples1, samples2): s2 = reshape_latent_to(s1.shape, s2) samples_out["samples"] = s1 + s2 - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentSubtract: +class LatentSubtract(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} - - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" + def define_schema(cls): + return io.Schema( + node_id="LatentSubtract", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - def op(self, samples1, samples2): + @classmethod + def execute(cls, samples1, samples2) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -51,41 +65,49 @@ def op(self, samples1, samples2): s2 = reshape_latent_to(s1.shape, s2) samples_out["samples"] = s1 - s2 - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentMultiply: +class LatentMultiply(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} - - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" + def define_schema(cls): + return io.Schema( + node_id="LatentMultiply", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/advanced" - - def op(self, samples, multiplier): + @classmethod + def execute(cls, samples, multiplier) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] samples_out["samples"] = s1 * multiplier - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentInterpolate: +class LatentInterpolate(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), - "samples2": ("LATENT",), - "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - }} + def define_schema(cls): + return io.Schema( + node_id="LatentInterpolate", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" - - def op(self, samples1, samples2, ratio): + @classmethod + def execute(cls, samples1, samples2, ratio) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -104,19 +126,26 @@ def op(self, samples1, samples2, ratio): st = torch.nan_to_num(t / mt) samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentConcat: +class LatentConcat(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}} - - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" + def define_schema(cls): + return io.Schema( + node_id="LatentConcat", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/advanced" - - def op(self, samples1, samples2, dim): + @classmethod + def execute(cls, samples1, samples2, dim) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] @@ -136,22 +165,27 @@ def op(self, samples1, samples2, dim): dim = -3 samples_out["samples"] = torch.cat(c, dim=dim) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentCut: +class LatentCut(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"samples": ("LATENT",), - "dim": (["x", "y", "t"], ), - "index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}), - "amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}} - - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced" + def define_schema(cls): + return io.Schema( + node_id="LatentCut", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Combo.Input("dim", options=["x", "y", "t"]), + io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Latent.Output(), + ], + ) - def op(self, samples, dim, index, amount): + @classmethod + def execute(cls, samples, dim, index, amount) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] @@ -171,19 +205,25 @@ def op(self, samples, dim, index, amount): amount = min(-index, amount) samples_out["samples"] = torch.narrow(s1, dim, index, amount) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentBatch: +class LatentBatch(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} - - RETURN_TYPES = ("LATENT",) - FUNCTION = "batch" - - CATEGORY = "latent/batch" + def define_schema(cls): + return io.Schema( + node_id="LatentBatch", + category="latent/batch", + inputs=[ + io.Latent.Input("samples1"), + io.Latent.Input("samples2"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - def batch(self, samples1, samples2): + @classmethod + def execute(cls, samples1, samples2) -> io.NodeOutput: samples_out = samples1.copy() s1 = samples1["samples"] s2 = samples2["samples"] @@ -192,20 +232,25 @@ def batch(self, samples1, samples2): s = torch.cat((s1, s2), dim=0) samples_out["samples"] = s samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentBatchSeedBehavior: +class LatentBatchSeedBehavior(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} - - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" + def define_schema(cls): + return io.Schema( + node_id="LatentBatchSeedBehavior", + category="latent/advanced", + inputs=[ + io.Latent.Input("samples"), + io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - CATEGORY = "latent/advanced" - - def op(self, samples, seed_behavior): + @classmethod + def execute(cls, samples, seed_behavior) -> io.NodeOutput: samples_out = samples.copy() latent = samples["samples"] if seed_behavior == "random": @@ -215,41 +260,50 @@ def op(self, samples, seed_behavior): batch_number = samples_out.get("batch_index", [0])[0] samples_out["batch_index"] = [batch_number] * latent.shape[0] - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentApplyOperation: +class LatentApplyOperation(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT",), - "operation": ("LATENT_OPERATION",), - }} - - RETURN_TYPES = ("LATENT",) - FUNCTION = "op" - - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True + def define_schema(cls): + return io.Schema( + node_id="LatentApplyOperation", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Latent.Input("samples"), + io.LatentOperation.Input("operation"), + ], + outputs=[ + io.Latent.Output(), + ], + ) - def op(self, samples, operation): + @classmethod + def execute(cls, samples, operation) -> io.NodeOutput: samples_out = samples.copy() s1 = samples["samples"] samples_out["samples"] = operation(latent=s1) - return (samples_out,) + return io.NodeOutput(samples_out) -class LatentApplyOperationCFG: +class LatentApplyOperationCFG(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "operation": ("LATENT_OPERATION",), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" - - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True + def define_schema(cls): + return io.Schema( + node_id="LatentApplyOperationCFG", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.LatentOperation.Input("operation"), + ], + outputs=[ + io.Model.Output(), + ], + ) - def patch(self, model, operation): + @classmethod + def execute(cls, model, operation) -> io.NodeOutput: m = model.clone() def pre_cfg_function(args): @@ -261,21 +315,25 @@ def pre_cfg_function(args): return conds_out m.set_model_sampler_pre_cfg_function(pre_cfg_function) - return (m, ) + return io.NodeOutput(m) -class LatentOperationTonemapReinhard: +class LatentOperationTonemapReinhard(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), - }} - - RETURN_TYPES = ("LATENT_OPERATION",) - FUNCTION = "op" + def define_schema(cls): + return io.Schema( + node_id="LatentOperationTonemapReinhard", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01), + ], + outputs=[ + io.LatentOperation.Output(), + ], + ) - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True - - def op(self, multiplier): + @classmethod + def execute(cls, multiplier) -> io.NodeOutput: def tonemap_reinhard(latent, **kwargs): latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None] normalized_latent = latent / latent_vector_magnitude @@ -291,39 +349,27 @@ def tonemap_reinhard(latent, **kwargs): new_magnitude *= top return normalized_latent * new_magnitude - return (tonemap_reinhard,) + return io.NodeOutput(tonemap_reinhard) + +class LatentOperationSharpen(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentOperationSharpen", + category="latent/advanced/operations", + is_experimental=True, + inputs=[ + io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1), + io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01), + ], + outputs=[ + io.LatentOperation.Output(), + ], + ) -class LatentOperationSharpen: @classmethod - def INPUT_TYPES(s): - return {"required": { - "sharpen_radius": ("INT", { - "default": 9, - "min": 1, - "max": 31, - "step": 1 - }), - "sigma": ("FLOAT", { - "default": 1.0, - "min": 0.1, - "max": 10.0, - "step": 0.1 - }), - "alpha": ("FLOAT", { - "default": 0.1, - "min": 0.0, - "max": 5.0, - "step": 0.01 - }), - }} - - RETURN_TYPES = ("LATENT_OPERATION",) - FUNCTION = "op" - - CATEGORY = "latent/advanced/operations" - EXPERIMENTAL = True - - def op(self, sharpen_radius, sigma, alpha): + def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput: def sharpen(latent, **kwargs): luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None] normalized_latent = latent / luminance @@ -340,19 +386,27 @@ def sharpen(latent, **kwargs): sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius] return luminance * sharpened - return (sharpen,) - -NODE_CLASS_MAPPINGS = { - "LatentAdd": LatentAdd, - "LatentSubtract": LatentSubtract, - "LatentMultiply": LatentMultiply, - "LatentInterpolate": LatentInterpolate, - "LatentConcat": LatentConcat, - "LatentCut": LatentCut, - "LatentBatch": LatentBatch, - "LatentBatchSeedBehavior": LatentBatchSeedBehavior, - "LatentApplyOperation": LatentApplyOperation, - "LatentApplyOperationCFG": LatentApplyOperationCFG, - "LatentOperationTonemapReinhard": LatentOperationTonemapReinhard, - "LatentOperationSharpen": LatentOperationSharpen, -} + return io.NodeOutput(sharpen) + + +class LatentExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LatentAdd, + LatentSubtract, + LatentMultiply, + LatentInterpolate, + LatentConcat, + LatentCut, + LatentBatch, + LatentBatchSeedBehavior, + LatentApplyOperation, + LatentApplyOperationCFG, + LatentOperationTonemapReinhard, + LatentOperationSharpen, + ] + + +async def comfy_entrypoint() -> LatentExtension: + return LatentExtension()