|
1 | 1 | from __future__ import annotations |
2 | 2 | from typing import Any |
| 3 | +import kornia |
3 | 4 | import numpy as np |
4 | 5 | import torch |
5 | 6 | import torch.jit |
|
22 | 23 | from . import mat |
23 | 24 | from .util import ( |
24 | 25 | BlurKernel, |
| 26 | + image_to_torch, |
25 | 27 | mask_blur, |
26 | 28 | gaussian_blur, |
27 | 29 | binary_erosion, |
@@ -141,7 +143,7 @@ def execute(cls, head: str, patch: str): # type: ignore |
141 | 143 | patch_file = folder_paths.get_full_path("inpaint", patch) |
142 | 144 | inpaint_lora = comfy.utils.load_torch_file(patch_file, safe_load=True) |
143 | 145 |
|
144 | | - return io.NodeOutput(inpaint_head_model, inpaint_lora) |
| 146 | + return io.NodeOutput((inpaint_head_model, inpaint_lora)) |
145 | 147 |
|
146 | 148 |
|
147 | 149 | class InpaintBlockPatch: |
@@ -476,6 +478,77 @@ def _upscale(cls, upscale_model, image: Tensor, device): |
476 | 478 | return torch.clamp(s, min=0, max=1.0) |
477 | 479 |
|
478 | 480 |
|
| 481 | +class ColorMatch(io.ComfyNode): |
| 482 | + @classmethod |
| 483 | + def define_schema(cls): |
| 484 | + return io.Schema( |
| 485 | + node_id="INPAINT_ColorMatch", |
| 486 | + display_name="Color Match (Masked)", |
| 487 | + category="inpaint", |
| 488 | + inputs=[ |
| 489 | + io.Image.Input("target"), |
| 490 | + io.Image.Input("reference"), |
| 491 | + io.Mask.Input("exclude_mask", optional=True), |
| 492 | + io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01), |
| 493 | + ], |
| 494 | + outputs=[io.Image.Output("image")], |
| 495 | + ) |
| 496 | + |
| 497 | + @classmethod |
| 498 | + def execute( # type: ignore |
| 499 | + cls, target: Tensor, reference: Tensor, exclude_mask: Tensor | None, strength: float |
| 500 | + ): |
| 501 | + # from https://github.com/kijai/ComfyUI-KJNodes (GPLv3), modified with mask support |
| 502 | + if strength <= 0.0: |
| 503 | + return io.NodeOutput(target) |
| 504 | + |
| 505 | + device = model_management.get_torch_device() |
| 506 | + |
| 507 | + src_bchw = image_to_torch(target.to(device)) |
| 508 | + ref_bchw = image_to_torch(reference.to(device)) |
| 509 | + Bs, Cs, Hs, Ws = src_bchw.shape |
| 510 | + Br, Cr, Hr, Wr = ref_bchw.shape |
| 511 | + |
| 512 | + src_lab = kornia.color.rgb_to_lab(src_bchw) |
| 513 | + ref_lab = kornia.color.rgb_to_lab(ref_bchw) |
| 514 | + |
| 515 | + src_lab_flat = src_lab_masked = src_lab.view(Bs, Cs, Hs * Ws) |
| 516 | + ref_lab_flat = ref_lab_masked = ref_lab.view(Br, Cr, Hr * Wr) |
| 517 | + |
| 518 | + if exclude_mask is not None: |
| 519 | + mask = mask_to_torch(exclude_mask).to(device) |
| 520 | + Bm, _, Hm, Wm = mask.shape |
| 521 | + src_mask, ref_mask = mask, mask |
| 522 | + if Hm != Hs or Wm != Ws: |
| 523 | + src_mask = F.interpolate(mask, size=(Hs, Ws), mode="bilinear") |
| 524 | + src_mask_flat = src_mask.view(Bm, 1, Hs * Ws) < 0.5 |
| 525 | + if Hr == Hs and Wr == Ws: |
| 526 | + ref_mask_flat = src_mask_flat |
| 527 | + else: |
| 528 | + if Hm != Hr or Wm != Wr: |
| 529 | + ref_mask = F.interpolate(mask, size=(Hr, Wr), mode="bilinear") |
| 530 | + ref_mask_flat = ref_mask.view(Bm, 1, Hr * Wr) < 0.5 |
| 531 | + src_lab_masked = src_lab_flat * src_mask_flat |
| 532 | + ref_lab_masked = ref_lab_flat * ref_mask_flat |
| 533 | + |
| 534 | + src_std, src_mean = torch.std_mean(src_lab_masked, dim=-1, keepdim=True, unbiased=False) |
| 535 | + ref_std, ref_mean = torch.std_mean(ref_lab_masked, dim=-1, keepdim=True, unbiased=False) |
| 536 | + src_std = src_std.clamp_min_(1e-6) |
| 537 | + |
| 538 | + if Br == 1 and Bs > 1: |
| 539 | + ref_mean = ref_mean.expand(Bs, -1, -1) |
| 540 | + ref_std = ref_std.expand(Bs, -1, -1) |
| 541 | + |
| 542 | + corrected_lab_flat = (src_lab_flat - src_mean) * (ref_std / src_std) + ref_mean |
| 543 | + corrected_lab = corrected_lab_flat.view(Bs, Cs, Hs, Ws) |
| 544 | + |
| 545 | + out = kornia.color.lab_to_rgb(corrected_lab) |
| 546 | + if strength < 1.0: |
| 547 | + out = (1.0 - strength) * src_bchw + strength * out |
| 548 | + |
| 549 | + return io.NodeOutput(to_comfy(out).cpu().float().clamp_(0, 1)) |
| 550 | + |
| 551 | + |
479 | 552 | class DenoiseToCompositingMask(io.ComfyNode): |
480 | 553 | @classmethod |
481 | 554 | def define_schema(cls): |
|
0 commit comments