|
12 | 12 | def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
13 | 13 | source = source.to(destination.device)
|
14 | 14 | if resize_source:
|
15 |
| - source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") |
| 15 | + source = torch.nn.functional.interpolate(source, size=(destination.shape[-2], destination.shape[-1]), mode="bilinear") |
16 | 16 |
|
17 | 17 | source = comfy.utils.repeat_to_batch_size(source, destination.shape[0])
|
18 | 18 |
|
19 |
| - x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) |
20 |
| - y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) |
| 19 | + x = max(-source.shape[-1] * multiplier, min(x, destination.shape[-1] * multiplier)) |
| 20 | + y = max(-source.shape[-2] * multiplier, min(y, destination.shape[-2] * multiplier)) |
21 | 21 |
|
22 | 22 | left, top = (x // multiplier, y // multiplier)
|
23 |
| - right, bottom = (left + source.shape[3], top + source.shape[2],) |
| 23 | + right, bottom = (left + source.shape[-1], top + source.shape[-2],) |
24 | 24 |
|
25 | 25 | if mask is None:
|
26 | 26 | mask = torch.ones_like(source)
|
27 | 27 | else:
|
28 | 28 | mask = mask.to(destination.device, copy=True)
|
29 |
| - mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") |
| 29 | + mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[-2], source.shape[-1]), mode="bilinear") |
30 | 30 | mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0])
|
31 | 31 |
|
32 | 32 | # calculate the bounds of the source that will be overlapping the destination
|
33 | 33 | # this prevents the source trying to overwrite latent pixels that are out of bounds
|
34 | 34 | # of the destination
|
35 |
| - visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) |
| 35 | + visible_width, visible_height = (destination.shape[-1] - left + min(0, x), destination.shape[-2] - top + min(0, y),) |
36 | 36 |
|
37 | 37 | mask = mask[:, :, :visible_height, :visible_width]
|
| 38 | + if mask.ndim < source.ndim: |
| 39 | + mask = mask.unsqueeze(1) |
| 40 | + |
38 | 41 | inverse_mask = torch.ones_like(mask) - mask
|
39 | 42 |
|
40 |
| - source_portion = mask * source[:, :, :visible_height, :visible_width] |
41 |
| - destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] |
| 43 | + source_portion = mask * source[..., :visible_height, :visible_width] |
| 44 | + destination_portion = inverse_mask * destination[..., top:bottom, left:right] |
42 | 45 |
|
43 |
| - destination[:, :, top:bottom, left:right] = source_portion + destination_portion |
| 46 | + destination[..., top:bottom, left:right] = source_portion + destination_portion |
44 | 47 | return destination
|
45 | 48 |
|
46 | 49 | class LatentCompositeMasked:
|
|
0 commit comments