Skip to content

Commit 8b90e50

Browse files
Properly handle and reshape masks when used on 3d latents.
1 parent 6ee066a commit 8b90e50

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

comfy/sampler_helpers.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
import torch
22
import comfy.model_management
33
import comfy.conds
4+
import comfy.utils
45

56
def prepare_mask(noise_mask, shape, device):
6-
"""ensures noise mask is of proper dimensions"""
7-
noise_mask = torch.nn.functional.interpolate(noise_mask.reshape((-1, 1, noise_mask.shape[-2], noise_mask.shape[-1])), size=(shape[2], shape[3]), mode="bilinear")
8-
noise_mask = torch.cat([noise_mask] * shape[1], dim=1)
9-
noise_mask = comfy.utils.repeat_to_batch_size(noise_mask, shape[0])
10-
noise_mask = noise_mask.to(device)
11-
return noise_mask
7+
return comfy.utils.reshape_mask(noise_mask, shape).to(device)
128

139
def get_models_from_cond(cond, model_type):
1410
models = []

comfy/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -848,3 +848,24 @@ def update_absolute(self, value, total=None, preview=None):
848848

849849
def update(self, value):
850850
self.update_absolute(self.current + value)
851+
852+
def reshape_mask(input_mask, output_shape):
853+
dims = len(output_shape) - 2
854+
855+
if dims == 1:
856+
scale_mode = "linear"
857+
858+
if dims == 2:
859+
mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
860+
scale_mode = "bilinear"
861+
862+
if dims == 3:
863+
if len(input_mask.shape) < 5:
864+
mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
865+
scale_mode = "trilinear"
866+
867+
mask = torch.nn.functional.interpolate(mask, size=output_shape[2:], mode=scale_mode)
868+
if mask.shape[1] < output_shape[1]:
869+
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
870+
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
871+
return mask

0 commit comments

Comments
 (0)