-
Notifications
You must be signed in to change notification settings - Fork 218
Make it work in new Forge (Gradio 4) #215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 4 commits
fe0c82a
433c6eb
61dba99
ac4db78
470b0f8
e25c240
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,7 +10,7 @@ | |||||||||||||||||||||||||||||||
| from scipy.ndimage import binary_dilation | ||||||||||||||||||||||||||||||||
| from modules import scripts, shared, script_callbacks | ||||||||||||||||||||||||||||||||
| from modules.ui import gr_show | ||||||||||||||||||||||||||||||||
| from modules.ui_components import FormRow | ||||||||||||||||||||||||||||||||
| from modules.ui_components import FormRow, ToolButton | ||||||||||||||||||||||||||||||||
| from modules.safe import unsafe_torch_load, load | ||||||||||||||||||||||||||||||||
| from modules.processing import StableDiffusionProcessingImg2Img, StableDiffusionProcessing | ||||||||||||||||||||||||||||||||
| from modules.devices import device, torch_gc, cpu | ||||||||||||||||||||||||||||||||
|
|
@@ -21,6 +21,18 @@ | |||||||||||||||||||||||||||||||
| from scripts.auto import clear_sem_sam_cache, register_auto_sam, semantic_segmentation, sem_sam_garbage_collect, image_layer_internal, categorical_mask_image | ||||||||||||||||||||||||||||||||
| from scripts.process_params import SAMProcessUnit, max_cn_num | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| import importlib.metadata | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def get_gradio_version(): | ||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||
| version = importlib.metadata.version("gradio") | ||||||||||||||||||||||||||||||||
| return version | ||||||||||||||||||||||||||||||||
| except importlib.metadata.PackageNotFoundError: | ||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Example usage | ||||||||||||||||||||||||||||||||
| gradio_version = get_gradio_version() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| refresh_symbol = '\U0001f504' # 🔄 | ||||||||||||||||||||||||||||||||
| sam_model_cache = OrderedDict() | ||||||||||||||||||||||||||||||||
|
|
@@ -35,16 +47,6 @@ | |||||||||||||||||||||||||||||||
| txt2img_height: gr.Slider = None | ||||||||||||||||||||||||||||||||
| img2img_width: gr.Slider = None | ||||||||||||||||||||||||||||||||
| img2img_height: gr.Slider = None | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| class ToolButton(gr.Button, gr.components.FormComponent): | ||||||||||||||||||||||||||||||||
| """Small button with single emoji as text, fits inside gradio forms""" | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def __init__(self, **kwargs): | ||||||||||||||||||||||||||||||||
| super().__init__(variant="tool", **kwargs) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def get_block_name(self): | ||||||||||||||||||||||||||||||||
| return "button" | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def show_masks(image_np, masks: np.ndarray, alpha=0.5): | ||||||||||||||||||||||||||||||||
|
|
@@ -56,7 +58,38 @@ def show_masks(image_np, masks: np.ndarray, alpha=0.5): | |||||||||||||||||||||||||||||||
| return image.astype(np.uint8) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image): | ||||||||||||||||||||||||||||||||
| def update_mask_gr_four(mask_gallery, chosen_mask, dilation_amt, input_image): | ||||||||||||||||||||||||||||||||
| print("Dilation Amount: ", dilation_amt) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| # Check if mask_gallery is a list | ||||||||||||||||||||||||||||||||
| if isinstance(mask_gallery, list): | ||||||||||||||||||||||||||||||||
| # Extract the image from the (image, caption) tuple | ||||||||||||||||||||||||||||||||
| image_tuple = mask_gallery[chosen_mask + 3] | ||||||||||||||||||||||||||||||||
| if isinstance(image_tuple, tuple) and len(image_tuple) == 2: | ||||||||||||||||||||||||||||||||
| mask_image = image_tuple[0] # Get the image part of the tuple | ||||||||||||||||||||||||||||||||
| if isinstance(mask_image, str): # If it's a file path | ||||||||||||||||||||||||||||||||
| mask_image = Image.open(mask_image) | ||||||||||||||||||||||||||||||||
| elif isinstance(mask_image, np.ndarray): # If it's an ndarray | ||||||||||||||||||||||||||||||||
| mask_image = Image.fromarray(mask_image) | ||||||||||||||||||||||||||||||||
| # Else, it is already a PIL.Image | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| raise TypeError("Expected a tuple (image, caption) in mask_gallery") | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| # If not a list, use the provided mask_gallery as-is | ||||||||||||||||||||||||||||||||
| mask_image = mask_gallery | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| binary_img = np.array(mask_image.convert('1')) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if dilation_amt: | ||||||||||||||||||||||||||||||||
| mask_image, binary_img = dilate_mask(binary_img, dilation_amt) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| blended_image = Image.fromarray(show_masks(np.array(input_image), binary_img.astype(np.bool_)[None, ...])) | ||||||||||||||||||||||||||||||||
| matted_image = np.array(input_image) | ||||||||||||||||||||||||||||||||
| matted_image[~binary_img] = np.array([0, 0, 0, 0]) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| return [blended_image, mask_image, Image.fromarray(matted_image)] | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def update_mask_gr_three(mask_gallery, chosen_mask, dilation_amt, input_image): | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
| def update_mask_gr_three(mask_gallery, chosen_mask, dilation_amt, input_image): | |
| def update_mask(mask_gallery, chosen_mask, dilation_amt, input_image): | |
| print("Dilation Amount: ", dilation_amt) | |
| if isinstance(mask_gallery, list): | |
| image_data = mask_gallery[chosen_mask + 3] | |
| # In Gradio 4 or above, this is an (image, caption) tuple | |
| if isinstance(image_data, tuple) and len(image_data) == 2: | |
| mask_image = Image.open(image_data[0]) | |
| # In Gradio 3 this is a dict with 'name', 'data', 'is_file' keys | |
| elif isinstance(image_data, dict) and 'name' in image_data: | |
| mask_image = Image.open(image_data['name']) | |
| else: | |
| raise TypeError("Cannot locate mask image while expanding mask") | |
| else: | |
| mask_image = mask_gallery |
This implies that we would crash if np.ndarray or PIL images are passed from a Gradio 4 gallery, but I actually haven't seen this case while using the extension.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any specific reason for the eval ? It's not great for security reasons. In theory this should be json compliant and json.loads should do the trick.
| points = eval(points) # Convert string representation of list back to list | |
| points = json.loads(points) # Convert string representation of list back to list |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have you seen a case where while using this extension we're stumbling upon
np.ndarrayor PIL images ? I've only filename strings passed here.