diff --git a/shortcodes/stable_diffusion/txt2mask.py b/shortcodes/stable_diffusion/txt2mask.py index 81ade7d..a2b8c67 100644 --- a/shortcodes/stable_diffusion/txt2mask.py +++ b/shortcodes/stable_diffusion/txt2mask.py @@ -14,6 +14,8 @@ def run_block(self, pargs, kwargs, context, content): from matplotlib import pyplot as plt import cv2 import numpy + from modules.images import flatten + from modules.shared import opts brush_mask_mode = self.Unprompted.parse_advanced(kwargs["mode"],context) if "mode" in kwargs else "add" self.show = True if "show" in pargs else False @@ -99,7 +101,8 @@ def get_mask(): transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), transforms.Resize((512, 512)), ]) - img = transform(self.Unprompted.shortcode_user_vars["init_images"][0]).unsqueeze(0) + flattened_input = flatten(self.Unprompted.shortcode_user_vars["init_images"][0], opts.img2img_background_color) + img = transform(flattened_input).unsqueeze(0) # predict with torch.no_grad(): @@ -162,4 +165,4 @@ def after(self,p=None,processed=None): processed.images.append(self.image_mask) self.image_mask = None self.show = False - return processed \ No newline at end of file + return processed