Skip to content

Commit d8bc82f

Browse files
authored
Add feathering to 180 mask node
1 parent 746960e commit d8bc82f

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

nodes.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,6 +1481,7 @@ def INPUT_TYPES(s) -> Dict:
14811481
["360", "180"],
14821482
{"default": "180"},
14831483
),
1484+
"feather": ("INT", {"default": 0}),
14841485
},
14851486
}
14861487

@@ -1492,17 +1493,35 @@ def INPUT_TYPES(s) -> Dict:
14921493
CATEGORY = "pytorch360convert/mask"
14931494

14941495
def mask_180_to_360(
1495-
self, image: torch.Tensor, input_mode: str = "180"
1496+
self, image: torch.Tensor, input_mode: str = "180", feather: int = 0
14961497
) -> Tuple[torch.Tensor]:
14971498
assert image.dim() == 4, f"image should have 4 dimensions, got {image.dim()}"
14981499
_, H, W, _ = image.shape
1500+
1501+
# For 360 input, the valid region is only half width
14991502
if input_mode == "360":
15001503
W = W // 2
1504+
15011505
pad_left = W // 2
15021506
pad_right = W - pad_left
1507+
total_width = W + pad_left + pad_right # == 2*W
15031508

1504-
mask = torch.ones(1, 1, H, W, dtype=image.dtype, device=image.device)
1505-
mask_padded = torch.nn.functional.pad(
1506-
mask, (pad_left, pad_right), mode="constant", value=0.0
1507-
)
1508-
return (mask_padded[:, 0, ...],)
1509+
# Start with zeros everywhere
1510+
mask = torch.zeros(H, total_width, dtype=image.dtype, device=image.device)
1511+
1512+
# Fill the main region with 1.0
1513+
mask[:, pad_left : pad_left + W] = 1.0
1514+
1515+
if feather > 0:
1516+
ramp = torch.linspace(
1517+
0, 1, steps=feather + 1, device=image.device, dtype=image.dtype
1518+
)[1:]
1519+
1520+
# Left feather (in the padded zero region, approaching the mask)
1521+
mask[:, pad_left - feather : pad_left] = ramp
1522+
1523+
# Right feather (in the padded zero region, approaching the mask)
1524+
mask[:, pad_left + W : pad_left + W + feather] = ramp.flip(0)
1525+
1526+
# [1, H, W] mask tensor
1527+
return (mask.unsqueeze(0),)

0 commit comments

Comments
 (0)