Skip to content

Commit 08a83e3

Browse files
committed
updates blend to merge different size imgs
1 parent d0e69db commit 08a83e3

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

post_processing/blend.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
import torch.nn.functional as F
23

34
class Blend:
45
def __init__(self):
@@ -26,6 +27,9 @@ def INPUT_TYPES(s):
2627
CATEGORY = "postprocessing"
2728

2829
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
30+
if image1.shape != image2.shape:
31+
image2 = self.crop_and_resize(image2, image1.shape)
32+
2933
blended_image = self.blend_mode(image1, image2, blend_mode)
3034
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
3135
blended_image = torch.clamp(blended_image, 0, 1)
@@ -48,6 +52,29 @@ def blend_mode(self, img1, img2, mode):
4852
def g(self, x):
4953
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
5054

55+
def crop_and_resize(self, img: torch.Tensor, target_shape: tuple):
56+
batch_size, img_h, img_w, img_c = img.shape
57+
_, target_h, target_w, _ = target_shape
58+
img_aspect_ratio = img_w / img_h
59+
target_aspect_ratio = target_w / target_h
60+
61+
# Crop center of the image to the target aspect ratio
62+
if img_aspect_ratio > target_aspect_ratio:
63+
new_width = int(img_h * target_aspect_ratio)
64+
left = (img_w - new_width) // 2
65+
img = img[:, :, left:left + new_width, :]
66+
else:
67+
new_height = int(img_w / target_aspect_ratio)
68+
top = (img_h - new_height) // 2
69+
img = img[:, top:top + new_height, :, :]
70+
71+
# Resize to target size
72+
img = img.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
73+
img = F.interpolate(img, size=(target_h, target_w), mode='bilinear', align_corners=False)
74+
img = img.permute(0, 2, 3, 1)
75+
76+
return img
77+
5178
NODE_CLASS_MAPPINGS = {
5279
"Blend": Blend,
5380
}

post_processing/blur.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def INPUT_TYPES(s):
3131
CATEGORY = "postprocessing"
3232

3333
def gaussian_kernel(self, kernel_size: int, sigma: float):
34-
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size))
34+
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
3535
d = torch.sqrt(x * x + y * y)
3636
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
3737
return g / g.sum()

0 commit comments

Comments
 (0)