Skip to content

Commit ce4cb23

Browse files
Make LatentCompositeMasked work with basic video latents. (#10023)
1 parent c8d2117 commit ce4cb23

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

comfy_extras/nodes_mask.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,38 @@
1212
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
1313
source = source.to(destination.device)
1414
if resize_source:
15-
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear")
15+
source = torch.nn.functional.interpolate(source, size=(destination.shape[-2], destination.shape[-1]), mode="bilinear")
1616

1717
source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
1818

19-
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier))
20-
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier))
19+
x = max(-source.shape[-1] * multiplier, min(x, destination.shape[-1] * multiplier))
20+
y = max(-source.shape[-2] * multiplier, min(y, destination.shape[-2] * multiplier))
2121

2222
left, top = (x // multiplier, y // multiplier)
23-
right, bottom = (left + source.shape[3], top + source.shape[2],)
23+
right, bottom = (left + source.shape[-1], top + source.shape[-2],)
2424

2525
if mask is None:
2626
mask = torch.ones_like(source)
2727
else:
2828
mask = mask.to(destination.device, copy=True)
29-
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear")
29+
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[-2], source.shape[-1]), mode="bilinear")
3030
mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
3131

3232
# calculate the bounds of the source that will be overlapping the destination
3333
# this prevents the source trying to overwrite latent pixels that are out of bounds
3434
# of the destination
35-
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),)
35+
visible_width, visible_height = (destination.shape[-1] - left + min(0, x), destination.shape[-2] - top + min(0, y),)
3636

3737
mask = mask[:, :, :visible_height, :visible_width]
38+
if mask.ndim < source.ndim:
39+
mask = mask.unsqueeze(1)
40+
3841
inverse_mask = torch.ones_like(mask) - mask
3942

40-
source_portion = mask * source[:, :, :visible_height, :visible_width]
41-
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right]
43+
source_portion = mask * source[..., :visible_height, :visible_width]
44+
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
4245

43-
destination[:, :, top:bottom, left:right] = source_portion + destination_portion
46+
destination[..., top:bottom, left:right] = source_portion + destination_portion
4447
return destination
4548

4649
class LatentCompositeMasked:

0 commit comments

Comments
 (0)