diff --git a/shortcodes/stable_diffusion/txt2mask.py b/shortcodes/stable_diffusion/txt2mask.py
index 81ade7d..6fa760a 100644
--- a/shortcodes/stable_diffusion/txt2mask.py
+++ b/shortcodes/stable_diffusion/txt2mask.py
@@ -14,6 +14,8 @@ def run_block(self, pargs, kwargs, context, content):
 		from matplotlib import pyplot as plt
 		import cv2
 		import numpy
+		from modules.images import flatten
+		from modules.shared import opts
 
 		brush_mask_mode = self.Unprompted.parse_advanced(kwargs["mode"],context) if "mode" in kwargs else "add"
 		self.show = True if "show" in pargs else False
@@ -24,6 +26,14 @@ def run_block(self, pargs, kwargs, context, content):
 		if "smoothing" in kwargs:
 			radius = int(kwargs["smoothing"])
 			smoothing_kernel = numpy.ones((radius,radius),numpy.float32)/(radius*radius)
+
+
+
+		# Pad the mask by applying a dilation
+		mask_padding = int(self.Unprompted.parse_advanced(kwargs["padding"],context) if "padding" in kwargs else 0)
+		padding_dilation_kernel = None
+		if (mask_padding > 0):
+			padding_dilation_kernel = numpy.ones((mask_padding, mask_padding), numpy.uint8)
 
 		prompts = content.split(self.Unprompted.Config.syntax.delimiter)
 		prompt_parts = len(prompts)
@@ -60,6 +70,7 @@ def process_mask_parts(these_preds,these_prompt_parts,mode,final_img = None):
 
 				# TODO: Figure out how to convert the plot above to numpy instead of re-loading image
 				img = cv2.imread(filename)
+				if padding_dilation_kernel is not None: img = cv2.dilate(img,padding_dilation_kernel,iterations=1)
 				if smoothing_kernel is not None: img = cv2.filter2D(img,-1,smoothing_kernel)
 				gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
 
@@ -99,7 +110,8 @@ def get_mask():
 				transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 				transforms.Resize((512, 512)),
 			])
-			img = transform(self.Unprompted.shortcode_user_vars["init_images"][0]).unsqueeze(0)
+			flattened_input = flatten(self.Unprompted.shortcode_user_vars["init_images"][0], opts.img2img_background_color)
+			img = transform(flattened_input).unsqueeze(0)
 
 			# predict
 			with torch.no_grad():
@@ -133,16 +145,7 @@ def get_mask():
 						if (pixel_data[0] == 0 and pixel_data[1] == 0 and pixel_data[2] == 0): black_pixels += 1
 				subject_size = 1 - black_pixels / total_pixels
 				self.Unprompted.shortcode_user_vars[kwargs["size_var"]] = subject_size
-
-			# Increase mask size with padding
-			mask_padding = int(self.Unprompted.parse_advanced(kwargs["padding"],context) if "padding" in kwargs else 0)
-			if (mask_padding > 0):
-				aspect_ratio = self.Unprompted.shortcode_user_vars["init_images"][0].width / self.Unprompted.shortcode_user_vars["init_images"][0].height
-				new_width = self.Unprompted.shortcode_user_vars["init_images"][0].width+mask_padding*2
-				new_height = round(new_width / aspect_ratio)
-				final_img = final_img.resize((new_width,new_height))
-				final_img = center_crop(final_img,self.Unprompted.shortcode_user_vars["init_images"][0].width,self.Unprompted.shortcode_user_vars["init_images"][0].height)
-
+
 			return final_img
 
 		# Set up processor parameters correctly
@@ -162,4 +165,4 @@ def after(self,p=None,processed=None):
 			processed.images.append(self.image_mask)
 			self.image_mask = None
 		self.show = False
-		return processed
\ No newline at end of file
+		return processed