Skip to content

Commit d74ecec

Browse files
committed
Add Color Match node with mask support
1 parent f1939c3 commit d74ecec

File tree

2 files changed

+75
-1
lines changed

2 files changed

+75
-1
lines changed

__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
3333
nodes.MaskedBlur,
3434
nodes.LoadInpaintModel,
3535
nodes.InpaintWithModel,
36+
nodes.ColorMatch,
3637
nodes.ExpandMask,
3738
nodes.ShrinkMask,
3839
nodes.StabilizeMask,

nodes.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22
from typing import Any
3+
import kornia
34
import numpy as np
45
import torch
56
import torch.jit
@@ -22,6 +23,7 @@
2223
from . import mat
2324
from .util import (
2425
BlurKernel,
26+
image_to_torch,
2527
mask_blur,
2628
gaussian_blur,
2729
binary_erosion,
@@ -141,7 +143,7 @@ def execute(cls, head: str, patch: str): # type: ignore
141143
patch_file = folder_paths.get_full_path("inpaint", patch)
142144
inpaint_lora = comfy.utils.load_torch_file(patch_file, safe_load=True)
143145

144-
return io.NodeOutput(inpaint_head_model, inpaint_lora)
146+
return io.NodeOutput((inpaint_head_model, inpaint_lora))
145147

146148

147149
class InpaintBlockPatch:
@@ -476,6 +478,77 @@ def _upscale(cls, upscale_model, image: Tensor, device):
476478
return torch.clamp(s, min=0, max=1.0)
477479

478480

481+
class ColorMatch(io.ComfyNode):
482+
@classmethod
483+
def define_schema(cls):
484+
return io.Schema(
485+
node_id="INPAINT_ColorMatch",
486+
display_name="Color Match (Masked)",
487+
category="inpaint",
488+
inputs=[
489+
io.Image.Input("target"),
490+
io.Image.Input("reference"),
491+
io.Mask.Input("exclude_mask", optional=True),
492+
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
493+
],
494+
outputs=[io.Image.Output("image")],
495+
)
496+
497+
@classmethod
498+
def execute( # type: ignore
499+
cls, target: Tensor, reference: Tensor, exclude_mask: Tensor | None, strength: float
500+
):
501+
# from https://github.com/kijai/ComfyUI-KJNodes (GPLv3), modified with mask support
502+
if strength <= 0.0:
503+
return io.NodeOutput(target)
504+
505+
device = model_management.get_torch_device()
506+
507+
src_bchw = image_to_torch(target.to(device))
508+
ref_bchw = image_to_torch(reference.to(device))
509+
Bs, Cs, Hs, Ws = src_bchw.shape
510+
Br, Cr, Hr, Wr = ref_bchw.shape
511+
512+
src_lab = kornia.color.rgb_to_lab(src_bchw)
513+
ref_lab = kornia.color.rgb_to_lab(ref_bchw)
514+
515+
src_lab_flat = src_lab_masked = src_lab.view(Bs, Cs, Hs * Ws)
516+
ref_lab_flat = ref_lab_masked = ref_lab.view(Br, Cr, Hr * Wr)
517+
518+
if exclude_mask is not None:
519+
mask = mask_to_torch(exclude_mask).to(device)
520+
Bm, _, Hm, Wm = mask.shape
521+
src_mask, ref_mask = mask, mask
522+
if Hm != Hs or Wm != Ws:
523+
src_mask = F.interpolate(mask, size=(Hs, Ws), mode="bilinear")
524+
src_mask_flat = src_mask.view(Bm, 1, Hs * Ws) < 0.5
525+
if Hr == Hs and Wr == Ws:
526+
ref_mask_flat = src_mask_flat
527+
else:
528+
if Hm != Hr or Wm != Wr:
529+
ref_mask = F.interpolate(mask, size=(Hr, Wr), mode="bilinear")
530+
ref_mask_flat = ref_mask.view(Bm, 1, Hr * Wr) < 0.5
531+
src_lab_masked = src_lab_flat * src_mask_flat
532+
ref_lab_masked = ref_lab_flat * ref_mask_flat
533+
534+
src_std, src_mean = torch.std_mean(src_lab_masked, dim=-1, keepdim=True, unbiased=False)
535+
ref_std, ref_mean = torch.std_mean(ref_lab_masked, dim=-1, keepdim=True, unbiased=False)
536+
src_std = src_std.clamp_min_(1e-6)
537+
538+
if Br == 1 and Bs > 1:
539+
ref_mean = ref_mean.expand(Bs, -1, -1)
540+
ref_std = ref_std.expand(Bs, -1, -1)
541+
542+
corrected_lab_flat = (src_lab_flat - src_mean) * (ref_std / src_std) + ref_mean
543+
corrected_lab = corrected_lab_flat.view(Bs, Cs, Hs, Ws)
544+
545+
out = kornia.color.lab_to_rgb(corrected_lab)
546+
if strength < 1.0:
547+
out = (1.0 - strength) * src_bchw + strength * out
548+
549+
return io.NodeOutput(to_comfy(out).cpu().float().clamp_(0, 1))
550+
551+
479552
class DenoiseToCompositingMask(io.ComfyNode):
480553
@classmethod
481554
def define_schema(cls):

0 commit comments

Comments
 (0)