129 changes: 69 additions & 60 deletions comfy_extras/nodes_compositing.py
@@ -1,6 +1,9 @@
import torch
import comfy.utils
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


def resize_mask(mask, shape):
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
@@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
return out_image, out_alpha


class PorterDuffImageComposite:
class PorterDuffImageComposite(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PorterDuffImageComposite",
display_name="Porter-Duff Image Composite",
category="mask/compositing",
inputs=[
io.Image.Input("source"),
io.Mask.Input("source_alpha"),
io.Image.Input("destination"),
io.Mask.Input("destination_alpha"),
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"source": ("IMAGE",),
"source_alpha": ("MASK",),
"destination": ("IMAGE",),
"destination_alpha": ("MASK",),
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
},
}

RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "composite"
CATEGORY = "mask/compositing"

def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
out_images = []
out_alphas = []
@@ -150,65 +157,67 @@ def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destinatio
out_images.append(out_image)
out_alphas.append(out_alpha.squeeze(2))

result = (torch.stack(out_images), torch.stack(out_alphas))
return result
return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))


class SplitImageWithAlpha:
class SplitImageWithAlpha(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
}
}

CATEGORY = "mask/compositing"
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "split_image_with_alpha"

def split_image_with_alpha(self, image: torch.Tensor):
def define_schema(cls):
return io.Schema(
node_id="SplitImageWithAlpha",
display_name="Split Image with Alpha",
category="mask/compositing",
inputs=[
io.Image.Input("image"),
],
outputs=[
io.Image.Output(),
io.Mask.Output(),
],
)

@classmethod
def execute(cls, image: torch.Tensor) -> io.NodeOutput:
out_images = [i[:,:,:3] for i in image]
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
return result
return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))


class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="JoinImageWithAlpha",
display_name="Join Image with Alpha",
category="mask/compositing",
inputs=[
io.Image.Input("image"),
io.Mask.Input("alpha"),
],
outputs=[io.Image.Output()],
)

class JoinImageWithAlpha:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"alpha": ("MASK",),
}
}

CATEGORY = "mask/compositing"
RETURN_TYPES = ("IMAGE",)
FUNCTION = "join_image_with_alpha"

def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = min(len(image), len(alpha))
out_images = []

alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))

result = (torch.stack(out_images),)
return result
return io.NodeOutput(torch.stack(out_images))


NODE_CLASS_MAPPINGS = {
"PorterDuffImageComposite": PorterDuffImageComposite,
"SplitImageWithAlpha": SplitImageWithAlpha,
"JoinImageWithAlpha": JoinImageWithAlpha,
}
class CompositingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
PorterDuffImageComposite,
SplitImageWithAlpha,
JoinImageWithAlpha,
]


NODE_DISPLAY_NAME_MAPPINGS = {
"PorterDuffImageComposite": "Porter-Duff Image Composite",
"SplitImageWithAlpha": "Split Image with Alpha",
"JoinImageWithAlpha": "Join Image with Alpha",
}
async def comfy_entrypoint() -> CompositingExtension:
return CompositingExtension()
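

For readers unfamiliar with the comfy_api.latest interface this PR migrates to, below is a minimal sketch of the same pattern applied to a made-up node. The InvertMaskExample node and its behavior are hypothetical and exist only to isolate the boilerplate the diff adds: a classmethod define_schema returning an io.Schema, a classmethod execute returning io.NodeOutput instead of a bare tuple, and a ComfyExtension subclass plus comfy_entrypoint replacing the old NODE_CLASS_MAPPINGS / NODE_DISPLAY_NAME_MAPPINGS dictionaries.

# Minimal sketch of the node API this PR migrates to; the node itself is hypothetical.
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io


class InvertMaskExample(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        # io.Schema replaces the old INPUT_TYPES / RETURN_TYPES / FUNCTION / CATEGORY attributes.
        return io.Schema(
            node_id="InvertMaskExample",
            display_name="Invert Mask (Example)",
            category="mask/compositing",
            inputs=[io.Mask.Input("mask")],
            outputs=[io.Mask.Output()],
        )

    @classmethod
    def execute(cls, mask: torch.Tensor) -> io.NodeOutput:
        # Results are wrapped in io.NodeOutput rather than returned as a tuple.
        return io.NodeOutput(1.0 - mask)


class ExampleExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [InvertMaskExample]


async def comfy_entrypoint() -> ExampleExtension:
    return ExampleExtension()

Note that execute receives cls rather than self, mirroring the classmethod-style execute methods introduced throughout this diff.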