27 changes: 15 additions & 12 deletions shortcodes/stable_diffusion/txt2mask.py
@@ -14,6 +14,8 @@ def run_block(self, pargs, kwargs, context, content):
from matplotlib import pyplot as plt
import cv2
import numpy
from modules.images import flatten
from modules.shared import opts

brush_mask_mode = self.Unprompted.parse_advanced(kwargs["mode"],context) if "mode" in kwargs else "add"
self.show = True if "show" in pargs else False
@@ -24,6 +26,14 @@ def run_block(self, pargs, kwargs, context, content):
if "smoothing" in kwargs:
radius = int(kwargs["smoothing"])
smoothing_kernel = numpy.ones((radius,radius),numpy.float32)/(radius*radius)



# Pad the mask by applying a dilation
mask_padding = int(self.Unprompted.parse_advanced(kwargs["padding"],context) if "padding" in kwargs else 0)
padding_dilation_kernel = None
if (mask_padding > 0):
padding_dilation_kernel = numpy.ones((mask_padding, mask_padding), numpy.uint8)

prompts = content.split(self.Unprompted.Config.syntax.delimiter)
prompt_parts = len(prompts)
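
Note for reviewers: the kernel built above feeds a morphological dilation, which grows every white region of the mask outward by roughly half the kernel size per side. A minimal standalone sketch of the technique (the array sizes and values here are illustrative, not taken from this patch):

import cv2
import numpy

# Illustrative mask: a 16x16 white square in a 64x64 black field.
mask = numpy.zeros((64, 64), numpy.uint8)
mask[24:40, 24:40] = 255

mask_padding = 8
kernel = numpy.ones((mask_padding, mask_padding), numpy.uint8)
padded = cv2.dilate(mask, kernel, iterations=1)

# An all-ones KxK kernel grows each region by K-1 pixels per axis in total.
print(int(mask.sum() / 255), int(padded.sum() / 255))  # 256 -> 529
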
@@ -60,6 +70,7 @@ def process_mask_parts(these_preds,these_prompt_parts,mode,final_img = None):
# TODO: Figure out how to convert the plot above to numpy instead of re-loading image
img = cv2.imread(filename)

if padding_dilation_kernel is not None: img = cv2.dilate(img,padding_dilation_kernel,iterations=1)
if smoothing_kernel is not None: img = cv2.filter2D(img,-1,smoothing_kernel)

gray_image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
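
The order here matters: cv2.filter2D with the normalized all-ones smoothing kernel is a plain box blur, so applying it after cv2.dilate feathers the already-padded edge instead of blurring a mask that is about to grow again. A hedged sketch of that ordering (the file name and kernel sizes are placeholders):

import cv2
import numpy

img = cv2.imread("mask.png")  # placeholder path for a rendered mask

# Grow the masked regions first...
padding_kernel = numpy.ones((10, 10), numpy.uint8)
img = cv2.dilate(img, padding_kernel, iterations=1)

# ...then soften the new boundary with a normalized box blur.
radius = 20
box = numpy.ones((radius, radius), numpy.float32) / (radius * radius)
img = cv2.filter2D(img, -1, box)
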
Expand Down Expand Up @@ -99,7 +110,8 @@ def get_mask():
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
transforms.Resize((512, 512)),
])
img = transform(self.Unprompted.shortcode_user_vars["init_images"][0]).unsqueeze(0)
flattened_input = flatten(self.Unprompted.shortcode_user_vars["init_images"][0], opts.img2img_background_color)
img = transform(flattened_input).unsqueeze(0)

# predict
with torch.no_grad():
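
modules.images.flatten is the webui helper that composites an image's alpha channel onto opts.img2img_background_color before the tensor transform, presumably so a transparent init image no longer skews the segmentation. Roughly, it behaves like this simplified sketch (not the exact webui source):

from PIL import Image

def flatten(img, bgcolor):
    # Composite any transparency onto a solid background, then drop alpha.
    if img.mode == "RGBA":
        background = Image.new("RGBA", img.size, bgcolor)
        background.paste(img, mask=img)
        img = background
    return img.convert("RGB")
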
@@ -133,16 +145,7 @@ def get_mask():
if (pixel_data[0] == 0 and pixel_data[1] == 0 and pixel_data[2] == 0): black_pixels += 1
subject_size = 1 - black_pixels / total_pixels
self.Unprompted.shortcode_user_vars[kwargs["size_var"]] = subject_size

# Increase mask size with padding
mask_padding = int(self.Unprompted.parse_advanced(kwargs["padding"],context) if "padding" in kwargs else 0)
if (mask_padding > 0):
aspect_ratio = self.Unprompted.shortcode_user_vars["init_images"][0].width / self.Unprompted.shortcode_user_vars["init_images"][0].height
new_width = self.Unprompted.shortcode_user_vars["init_images"][0].width+mask_padding*2
new_height = round(new_width / aspect_ratio)
final_img = final_img.resize((new_width,new_height))
final_img = center_crop(final_img,self.Unprompted.shortcode_user_vars["init_images"][0].width,self.Unprompted.shortcode_user_vars["init_images"][0].height)


return final_img

# Set up processor parameters correctly
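
The block deleted above padded the mask by upscaling it and center-cropping back to the original size. That only behaves sensibly for regions near the image center: anything off-center drifts toward the edge instead of growing. A worked example of the drift (all numbers hypothetical):

# Old approach on a 512x512 image with padding=32:
width, mask_padding = 512, 32
scale = (width + mask_padding * 2) / width  # 576 / 512 = 1.125

# A mask pixel at x=100 scales to 112.5; the center crop then removes
# (576 - 512) / 2 = 32 from each side, leaving x = 80.5. The pixel moved
# 19.5px left rather than its region growing by ~32px on every side.
x = 100 * scale - (width * scale - width) / 2
print(x)  # 80.5

The cv2.dilate replacement grows every masked region in place, independent of where it sits in the frame.
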
@@ -162,4 +165,4 @@ def after(self,p=None,processed=None):
processed.images.append(self.image_mask)
self.image_mask = None
self.show = False
return processed
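
For manual testing, the new argument composes with the existing ones in a prompt. Assuming the standard [txt2mask] block syntax, something like the following (values are illustrative):

[txt2mask mode="add" show padding=30 smoothing=20]face[/txt2mask] a closeup portrait photo

Here padding=30 exercises the new dilation path, and show causes after() above to append the generated mask to the output images.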