Skip to content

Commit 4ebbdc8

Browse files
committed
Add blur-type parameter (box/linear/gaussian) for Expand/ShrinkMask nodes
1 parent 1ce339a commit 4ebbdc8

File tree

2 files changed

+39
-8
lines changed

2 files changed

+39
-8
lines changed

nodes.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
from . import mat
2222
from .util import (
23+
BlurKernel,
24+
mask_blur,
2325
gaussian_blur,
2426
binary_erosion,
2527
binary_dilation,
@@ -126,6 +128,7 @@ def INPUT_TYPES(s):
126128

127129
def load(self, head: str, patch: str):
128130
head_file = folder_paths.get_full_path("inpaint", head)
131+
assert head_file is not None, f"Inpaint head file not found in inpaint folder: {head}"
129132
inpaint_head_model = InpaintHead()
130133
sd = torch.load(head_file, map_location="cpu", weights_only=True)
131134
inpaint_head_model.load_state_dict(sd)
@@ -486,19 +489,20 @@ def INPUT_TYPES(cls):
486489
"mask": ("MASK",),
487490
"grow": ("INT", {"default": 16, "min": 0, "max": 8096, "step": 1}),
488491
"blur": ("INT", {"default": 7, "min": 0, "max": 8096, "step": 1}),
492+
"blur_type": (["box", "linear", "gaussian"], {"default": "gaussian"}),
489493
}
490494
}
491495

492496
RETURN_TYPES = ("MASK",)
493497
CATEGORY = "inpaint"
494498
FUNCTION = "expand"
495499

496-
def expand(self, mask: Tensor, grow: int, blur: int):
500+
def expand(self, mask: Tensor, grow: int, blur: int, blur_type: str):
497501
mask = mask_unsqueeze(mask)
498502
if grow > 0:
499503
mask = binary_dilation(mask, grow)
500504
if blur > 0:
501-
mask = gaussian_blur(mask, make_odd(blur))
505+
mask = mask_blur(mask, make_odd(blur), BlurKernel[blur_type])
502506
return (mask.squeeze(1),)
503507

504508

@@ -510,17 +514,18 @@ def INPUT_TYPES(cls):
510514
"mask": ("MASK",),
511515
"shrink": ("INT", {"default": 1, "min": 0, "max": 8096, "step": 1}),
512516
"blur": ("INT", {"default": 0, "min": 0, "max": 8096, "step": 1}),
517+
"blur_type": (["box", "linear", "gaussian"], {"default": "gaussian"}),
513518
}
514519
}
515520

516521
RETURN_TYPES = ("MASK",)
517522
CATEGORY = "inpaint"
518523
FUNCTION = "shrink"
519524

520-
def shrink(self, mask: Tensor, shrink: int, blur: int):
525+
def shrink(self, mask: Tensor, shrink: int, blur: int, blur_type: str):
521526
mask = mask_unsqueeze(mask)
522527
if shrink > 0:
523528
mask = binary_erosion(mask, shrink)
524529
if blur > 0:
525-
mask = gaussian_blur(mask, make_odd(blur))
530+
mask = mask_blur(mask, make_odd(blur), BlurKernel[blur_type])
526531
return (mask.squeeze(1),)

util.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
from enum import Enum
23
import torch
34
import torch.nn.functional as F
45
import numpy as np
@@ -20,7 +21,7 @@ def to_torch(image: Tensor, mask: Tensor | None = None):
2021
image = image.permute(0, 3, 1, 2) # BHWC -> BCHW
2122
if mask is not None:
2223
mask = mask_unsqueeze(mask)
23-
if image.shape[2:] != mask.shape[2:]:
24+
if mask is not None and image.shape[2:] != mask.shape[2:]:
2425
raise ValueError(
2526
f"Image and mask must be the same size. {image.shape[2:]} != {mask.shape[2:]}"
2627
)
@@ -80,11 +81,11 @@ def undo_resize_square(image: Tensor, original_size: tuple[int, int, int]):
8081
return image[:, :, 0 : prev_size - pad_h, 0 : prev_size - pad_w]
8182

8283

83-
def gaussian_blur(image: Tensor, radius: int, sigma: float = 0):
84+
def gaussian_blur(image: Tensor, size: int, sigma: float = 0):
8485
c = image.shape[-3]
8586
if sigma <= 0:
86-
sigma = 0.3 * (radius - 1) + 0.8
87-
return kornia.filters.gaussian_blur2d(image, (radius, radius), (sigma, sigma))
87+
sigma = 0.3 * (size - 1) + 0.8
88+
return kornia.filters.gaussian_blur2d(image, (size, size), (sigma, sigma))
8889

8990

9091
def binary_erosion(mask: Tensor, radius: int):
@@ -102,6 +103,31 @@ def binary_dilation(mask: Tensor, radius: int):
102103
return mask
103104

104105

106+
class BlurKernel(Enum):
107+
box = "box"
108+
linear = "linear"
109+
gaussian = "gaussian"
110+
111+
112+
def mask_blur(mask: Tensor, size: int, method=BlurKernel.gaussian):
113+
if method is BlurKernel.gaussian:
114+
return gaussian_blur(mask, size)
115+
116+
match method:
117+
case BlurKernel.box:
118+
kernel = torch.ones(1, size, device=mask.device) / size
119+
case BlurKernel.linear:
120+
kernel = torch.linspace(1 / size, (size - 1) / size, size, device=mask.device)
121+
kernel = torch.cat([kernel, torch.ones(1, device=mask.device), kernel.flip(0)])
122+
kernel = kernel / kernel.sum()
123+
kernel = kernel.unsqueeze(0)
124+
case _:
125+
raise ValueError(f"Unknown blur kernel: {method}")
126+
127+
mask = kornia.filters.filter2d_separable(mask, kernel, kernel, border_type="reflect")
128+
return mask
129+
130+
105131
def make_odd(x):
106132
if x > 0 and x % 2 == 0:
107133
return x + 1

0 commit comments

Comments
 (0)