diff --git a/apps/dash-image-segmentation/README.md b/apps/dash-image-segmentation/README.md index e7bf7001b..313051f79 100644 --- a/apps/dash-image-segmentation/README.md +++ b/apps/dash-image-segmentation/README.md @@ -37,7 +37,7 @@ python app.py ## Screenshot -![Screenshot of app](assets/screenshot.png) +![Screenshot of app](assets/github/screenshot.png) ## Acknowledgements diff --git a/apps/dash-image-segmentation/app.py b/apps/dash-image-segmentation/app.py index 155f95fb1..367a0ed29 100644 --- a/apps/dash-image-segmentation/app.py +++ b/apps/dash-image-segmentation/app.py @@ -1,424 +1,41 @@ import plotly.express as px import dash -from dash.dependencies import Input, Output, State -import dash_html_components as html -import dash_core_components as dcc +from dash import Dash, html, dcc, Input, Output, State, callback, callback_context import dash_bootstrap_components as dbc -import plot_common -import json -from shapes_to_segmentations import ( - compute_segmentations, - blend_image_and_classified_regions_pil, -) -from skimage import io as skio -from trainable_segmentation import multiscale_basic_features -import io -import base64 -import PIL.Image -import pickle -from time import time -from joblib import Memory - -memory = Memory("./joblib_cache", bytes_limit=3000000000, verbose=3) - -compute_features = memory.cache(multiscale_basic_features) - -DEFAULT_STROKE_WIDTH = 3 # gives line width of 2^3 = 8 - -DEFAULT_IMAGE_PATH = "assets/segmentation_img.jpg" - -SEG_FEATURE_TYPES = ["intensity", "edges", "texture"] - -# the number of different classes for labels -NUM_LABEL_CLASSES = 5 -DEFAULT_LABEL_CLASS = 0 -class_label_colormap = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2"] -class_labels = list(range(NUM_LABEL_CLASSES)) -# we can't have less colors than classes -assert NUM_LABEL_CLASSES <= len(class_label_colormap) +import dash_mantine_components as dmc -# Font and background colors associated with each theme -text_color = {"dark": "#95969A", "light": "#595959"} -card_color = {"dark": "#2D3038", "light": "#FFFFFF"} - - -def class_to_color(n): - return class_label_colormap[n] - - -def color_to_class(c): - return class_label_colormap.index(c) +from time import time +from utils.helper_functions import ( + class_to_color, + color_to_class, + shapes_to_key, + store_shapes_seg_pair, + look_up_seg, + save_img_classifier, + show_segmentation, +) +from utils.figures import make_default_figure, annotation_react +from utils.components import ( + description, + segmentation, + sidebar, + meta, + header_items, + modal_overlay, +) -img = skio.imread(DEFAULT_IMAGE_PATH) -features_dict = {} +external_stylesheets = [dbc.themes.FLATLY, "assets/css/app.css"] -external_stylesheets = [dbc.themes.BOOTSTRAP, "assets/segmentation-style.css"] app = dash.Dash(__name__, external_stylesheets=external_stylesheets) server = app.server app.title = "Interactive image segmentation based on machine learning" -def make_default_figure( - images=[DEFAULT_IMAGE_PATH], - stroke_color=class_to_color(DEFAULT_LABEL_CLASS), - stroke_width=DEFAULT_STROKE_WIDTH, - shapes=[], -): - fig = plot_common.dummy_fig() - plot_common.add_layout_images_to_fig(fig, images) - fig.update_layout( - { - "dragmode": "drawopenpath", - "shapes": shapes, - "newshape.line.color": stroke_color, - "newshape.line.width": stroke_width, - "margin": dict(l=0, r=0, b=0, t=0, pad=4), - } - ) - return fig - - -def shapes_to_key(shapes): - return json.dumps(shapes) - - -def store_shapes_seg_pair(d, key, seg, remove_old=True): - """ - Stores shapes 
and segmentation pair in dict d - seg is a PIL.Image object - if remove_old True, deletes all the old keys and values. - """ - bytes_to_encode = io.BytesIO() - seg.save(bytes_to_encode, format="png") - bytes_to_encode.seek(0) - data = base64.b64encode(bytes_to_encode.read()).decode() - if remove_old: - return {key: data} - d[key] = data - return d - - -def look_up_seg(d, key): - """ Returns a PIL.Image object """ - data = d[key] - img_bytes = base64.b64decode(data) - img = PIL.Image.open(io.BytesIO(img_bytes)) - return img - - -# Modal -with open("explanations.md", "r") as f: - howto_md = f.read() - -modal_overlay = dbc.Modal( - [ - dbc.ModalBody(html.Div([dcc.Markdown(howto_md)], id="howto-md")), - dbc.ModalFooter(dbc.Button("Close", id="howto-close", className="howto-bn")), - ], - id="modal", - size="lg", -) - -button_howto = dbc.Button( - "Learn more", - id="howto-open", - outline=True, - color="info", - # Turn off lowercase transformation for class .button in stylesheet - style={"textTransform": "none"}, -) - -button_github = dbc.Button( - "View Code on github", - outline=True, - color="primary", - href="https://github.com/plotly/dash-sample-apps/tree/master/apps/dash-image-segmentation", - id="gh-link", - style={"text-transform": "none"}, -) - -# Header -header = dbc.Navbar( - dbc.Container( - [ - dbc.Row( - [ - dbc.Col( - html.Img( - id="logo", - src=app.get_asset_url("dash-logo-new.png"), - height="30px", - ), - md="auto", - ), - dbc.Col( - [ - html.Div( - [ - html.H3("Interactive Machine Learning"), - html.P("Image segmentation"), - ], - id="app-title", - ) - ], - md=True, - align="center", - ), - ], - align="center", - ), - dbc.Row( - [ - dbc.Col( - [ - dbc.NavbarToggler(id="navbar-toggler"), - dbc.Collapse( - dbc.Nav( - [ - dbc.NavItem(button_howto), - dbc.NavItem(button_github), - ], - navbar=True, - ), - id="navbar-collapse", - navbar=True, - ), - modal_overlay, - ], - md=2, - ), - ], - align="center", - ), - ], - fluid=True, - ), - dark=True, - color="dark", - sticky="top", -) - -# Description -description = dbc.Col( - [ - dbc.Card( - id="description-card", - children=[ - dbc.CardHeader("Explanation"), - dbc.CardBody( - [ - dbc.Row( - [ - dbc.Col( - [ - html.Img( - src="assets/segmentation_img_example_marks.jpg", - width="200px", - ) - ], - md="auto", - ), - dbc.Col( - html.P( - "This is an example of interactive machine learning for image classification. " - "To train the classifier, draw some marks on the picture using different colors for " - 'different parts, like in the example image. Then enable "Show segmentation" to see the ' - "classes a Random Forest Classifier gave to regions of the image, based on the marks you " - "used as a guide. You may add more marks to clarify parts of the image where the " - "classifier was not successful and the classification will update." 
- ), - md=True, - ), - ] - ), - ] - ), - ], - ) - ], - md=12, -) - -# Image Segmentation -segmentation = [ - dbc.Card( - id="segmentation-card", - children=[ - dbc.CardHeader("Viewer"), - dbc.CardBody( - [ - # Wrap dcc.Loading in a div to force transparency when loading - html.Div( - id="transparent-loader-wrapper", - children=[ - dcc.Loading( - id="segmentations-loading", - type="circle", - children=[ - # Graph - dcc.Graph( - id="graph", - figure=make_default_figure(), - config={ - "modeBarButtonsToAdd": [ - "drawrect", - "drawopenpath", - "eraseshape", - ] - }, - ), - ], - ) - ], - ), - ] - ), - dbc.CardFooter( - [ - # Download links - html.A(id="download", download="classifier.json",), - html.Div( - children=[ - dbc.ButtonGroup( - [ - dbc.Button( - "Download classified image", - id="download-image-button", - outline=True, - ), - dbc.Button( - "Download classifier", - id="download-button", - outline=True, - ), - ], - size="lg", - style={"width": "100%"}, - ), - ], - ), - html.A(id="download-image", download="classified-image.png",), - ] - ), - ], - ) -] - -# sidebar -sidebar = [ - dbc.Card( - id="sidebar-card", - children=[ - dbc.CardHeader("Tools"), - dbc.CardBody( - [ - html.H6("Label class", className="card-title"), - # Label class chosen with buttons - html.Div( - id="label-class-buttons", - children=[ - dbc.Button( - "%2d" % (n,), - id={"type": "label-class-button", "index": n}, - style={"background-color": class_to_color(c)}, - ) - for n, c in enumerate(class_labels) - ], - ), - html.Hr(), - dbc.Form( - [ - dbc.FormGroup( - [ - dbc.Label( - "Width of annotation paintbrush", - html_for="stroke-width", - ), - # Slider for specifying stroke width - dcc.Slider( - id="stroke-width", - min=0, - max=6, - step=0.1, - value=DEFAULT_STROKE_WIDTH, - ), - ] - ), - dbc.FormGroup( - [ - html.H6( - id="stroke-width-display", - className="card-title", - ), - dbc.Label( - "Blurring parameter", - html_for="sigma-range-slider", - ), - dcc.RangeSlider( - id="sigma-range-slider", - min=0.01, - max=20, - step=0.01, - value=[0.5, 16], - ), - ] - ), - dbc.FormGroup( - [ - dbc.Label( - "Select features", - html_for="segmentation-features", - ), - dcc.Checklist( - id="segmentation-features", - options=[ - {"label": l.capitalize(), "value": l} - for l in SEG_FEATURE_TYPES - ], - value=["intensity", "edges"], - ), - ] - ), - # Indicate showing most recently computed segmentation - dcc.Checklist( - id="show-segmentation", - options=[ - { - "label": "Show segmentation", - "value": "Show segmentation", - } - ], - value=[], - ), - ] - ), - ] - ), - ], - ), -] - -meta = [ - html.Div( - id="no-display", - children=[ - # Store for user created masks - # data is a list of dicts describing shapes - dcc.Store(id="masks", data={"shapes": []}), - dcc.Store(id="classifier-store", data={}), - dcc.Store(id="classified-image-store", data=""), - dcc.Store(id="features_hash", data=""), - ], - ), - html.Div(id="download-dummy"), - html.Div(id="download-image-dummy"), -] - app.layout = html.Div( [ - header, + dmc.Header(height=70, padding="md", children=header_items), dbc.Container( [ dbc.Row(description), @@ -434,65 +51,26 @@ def look_up_seg(d, key): ) -# Converts image classifier to a JSON compatible encoding and creates a -# dictionary that can be downloaded -# see use_ml_image_segmentation_classifier.py -def save_img_classifier(clf, label_to_colors_args, segmenter_args): - clfbytes = io.BytesIO() - pickle.dump(clf, clfbytes) - clfb64 = base64.b64encode(clfbytes.getvalue()).decode() - return { - "classifier": clfb64, - 
"segmenter_args": segmenter_args, - "label_to_colors_args": label_to_colors_args, - } - - -def show_segmentation(image_path, mask_shapes, features, segmenter_args): - """ adds an image showing segmentations to a figure's layout """ - # add 1 because classifier takes 0 to mean no mask - shape_layers = [color_to_class(shape["line"]["color"]) + 1 for shape in mask_shapes] - label_to_colors_args = { - "colormap": class_label_colormap, - "color_class_offset": -1, - } - segimg, _, clf = compute_segmentations( - mask_shapes, - img_path=image_path, - shape_layers=shape_layers, - label_to_colors_args=label_to_colors_args, - features=features, - ) - # get the classifier that we can later store in the Store - classifier = save_img_classifier(clf, label_to_colors_args, segmenter_args) - segimgpng = plot_common.img_array_to_pil_image(segimg) - return (segimgpng, classifier) - - -@app.callback( - [ - Output("graph", "figure"), - Output("masks", "data"), - Output("stroke-width-display", "children"), - Output("classifier-store", "data"), - Output("classified-image-store", "data"), - ], - [ - Input("graph", "relayoutData"), - Input( - {"type": "label-class-button", "index": dash.dependencies.ALL}, - "n_clicks_timestamp", - ), - Input("stroke-width", "value"), - Input("show-segmentation", "value"), - Input("download-button", "n_clicks"), - Input("download-image-button", "n_clicks"), - Input("segmentation-features", "value"), - Input("sigma-range-slider", "value"), - ], - [State("masks", "data"),], +@callback( + Output("graph", "figure"), + Output("masks", "data"), + Output("stroke-width-display", "children"), + Output("classifier-store", "data"), + Output("classified-image-store", "data"), + Input("graph", "relayoutData"), + Input( + {"type": "label-class-button", "index": dash.dependencies.ALL}, + "n_clicks_timestamp", + ), + Input("stroke-width", "value"), + Input("show-segmentation", "value"), + Input("download-button", "n_clicks"), + Input("download-image-button", "n_clicks"), + Input("segmentation-features", "value"), + Input("sigma-range-slider", "value"), + State("masks", "data"), ) -def annotation_react( +def return_annotation_react( graph_relayoutData, any_label_class_button_value, stroke_width_value, @@ -503,83 +81,16 @@ def annotation_react( sigma_range_slider_value, masks_data, ): - classified_image_store_data = dash.no_update - classifier_store_data = dash.no_update - cbcontext = [p["prop_id"] for p in dash.callback_context.triggered][0] - if cbcontext in ["segmentation-features.value", "sigma-range-slider.value"] or ( - ("Show segmentation" in show_segmentation_value) - and (len(masks_data["shapes"]) > 0) - ): - segmentation_features_dict = { - "intensity": False, - "edges": False, - "texture": False, - } - for feat in segmentation_features_value: - segmentation_features_dict[feat] = True - t1 = time() - features = compute_features( - img, - **segmentation_features_dict, - sigma_min=sigma_range_slider_value[0], - sigma_max=sigma_range_slider_value[1], - ) - t2 = time() - print(t2 - t1) - if cbcontext == "graph.relayoutData": - if "shapes" in graph_relayoutData.keys(): - masks_data["shapes"] = graph_relayoutData["shapes"] - else: - return dash.no_update - stroke_width = int(round(2 ** (stroke_width_value))) - # find label class value by finding button with the most recent click - if any_label_class_button_value is None: - label_class_value = DEFAULT_LABEL_CLASS - else: - label_class_value = max( - enumerate(any_label_class_button_value), - key=lambda t: 0 if t[1] is None else t[1], - )[0] - - 
fig = make_default_figure( - stroke_color=class_to_color(label_class_value), - stroke_width=stroke_width, - shapes=masks_data["shapes"], - ) - # We want the segmentation to be computed - if ("Show segmentation" in show_segmentation_value) and ( - len(masks_data["shapes"]) > 0 - ): - segimgpng = None - try: - feature_opts = dict(segmentation_features_dict=segmentation_features_dict) - feature_opts["sigma_min"] = sigma_range_slider_value[0] - feature_opts["sigma_max"] = sigma_range_slider_value[1] - segimgpng, clf = show_segmentation( - DEFAULT_IMAGE_PATH, masks_data["shapes"], features, feature_opts - ) - if cbcontext == "download-button.n_clicks": - classifier_store_data = clf - if cbcontext == "download-image-button.n_clicks": - classified_image_store_data = plot_common.pil_image_to_uri( - blend_image_and_classified_regions_pil( - PIL.Image.open(DEFAULT_IMAGE_PATH), segimgpng - ) - ) - except ValueError: - # if segmentation fails, draw nothing - pass - images_to_draw = [] - if segimgpng is not None: - images_to_draw = [segimgpng] - fig = plot_common.add_layout_images_to_fig(fig, images_to_draw) - fig.update_layout(uirevision="segmentation") - return ( - fig, + return annotation_react( + graph_relayoutData, + any_label_class_button_value, + stroke_width_value, + show_segmentation_value, + download_button_n_clicks, + download_image_button_n_clicks, + segmentation_features_value, + sigma_range_slider_value, masks_data, - "Current paintbrush width: %d" % (stroke_width,), - classifier_store_data, - classified_image_store_data, ) @@ -608,7 +119,7 @@ def annotation_react( } """, Output("download-image", "href"), - [Input("classified-image-store", "data")], + Input("classified-image-store", "data"), ) # simulate a click on the element when download.href is updated @@ -634,15 +145,15 @@ def annotation_react( } """, Output("download-image-dummy", "children"), - [Input("download-image", "href")], + Input("download-image", "href"), ) - # Callback for modal popup -@app.callback( +@callback( Output("modal", "is_open"), - [Input("howto-open", "n_clicks"), Input("howto-close", "n_clicks")], - [State("modal", "is_open")], + Input("howto-open", "n_clicks"), + Input("howto-close", "n_clicks"), + State("modal", "is_open"), ) def toggle_modal(n1, n2, is_open): if n1 or n2: @@ -651,10 +162,10 @@ def toggle_modal(n1, n2, is_open): # we use a callback to toggle the collapse on small screens -@app.callback( +@callback( Output("navbar-collapse", "is_open"), - [Input("navbar-toggler", "n_clicks")], - [State("navbar-collapse", "is_open")], + Input("navbar-toggler", "n_clicks"), + State("navbar-collapse", "is_open"), ) def toggle_navbar_collapse(n, is_open): if n: diff --git a/apps/dash-image-segmentation/assets/css/app.css b/apps/dash-image-segmentation/assets/css/app.css new file mode 100644 index 000000000..8b3802ab0 --- /dev/null +++ b/apps/dash-image-segmentation/assets/css/app.css @@ -0,0 +1,61 @@ +/* Header */ +.header { + height: 10vh; + display: flex; + padding-left: 2%; + padding-right: 2%; + font-family: playfair display, sans-serif; + font-weight: bold; +} + +.header .header-title { + font-size: 5vh; +} +.subheader-title { + font-size: 1.5vh; +} + +.header-logos { + margin-left: auto; +} +.header-logos img { + margin-left: 3vh !important; + max-height: 5vh; +} + + +/* Demo button css */ +.demo-button { + font-size: 1.5vh; + font-family: Open Sans, sans-serif; + text-decoration: none; + -webkit-align-items: center; + -webkit-box-align: center; + -ms-flex-align: center; + align-items: center; + 
border-radius: 8px; + font-weight: 700; + -webkit-padding-start: 1rem; + padding-inline-start: 1rem; + -webkit-padding-end: 1rem; + padding-inline-end: 1rem; + color: #ffffff; + letter-spacing: 1.5px; + border: solid 1.5px transparent; + box-shadow: 2px 1000px 1px #0c0c0c inset; + background-image: linear-gradient(135deg, #7A76FF, #7A76FF, #7FE4FF); + -webkit-background-size: 200% 100%; + background-size: 200% 100%; + -webkit-background-position: 99%; + background-position: 99%; + background-origin: border-box; + transition: all .4s ease-in-out; + padding-top: 1vh; + padding-bottom: 1vh; + vertical-align: super; +} + +.demo-button:hover { + color: #7A76FF; + background-position: 0%; +} \ No newline at end of file diff --git a/apps/dash-image-segmentation/assets/segmentation-style.css b/apps/dash-image-segmentation/assets/css/segmentation-style.css similarity index 100% rename from apps/dash-image-segmentation/assets/segmentation-style.css rename to apps/dash-image-segmentation/assets/css/segmentation-style.css diff --git a/apps/dash-image-segmentation/assets/screenshot.png b/apps/dash-image-segmentation/assets/github/screenshot.png similarity index 100% rename from apps/dash-image-segmentation/assets/screenshot.png rename to apps/dash-image-segmentation/assets/github/screenshot.png diff --git a/apps/dash-image-segmentation/assets/dash-logo-new.png b/apps/dash-image-segmentation/assets/images/dash-logo-new.png similarity index 100% rename from apps/dash-image-segmentation/assets/dash-logo-new.png rename to apps/dash-image-segmentation/assets/images/dash-logo-new.png diff --git a/apps/dash-image-segmentation/assets/images/plotly-logo-dark-theme.png b/apps/dash-image-segmentation/assets/images/plotly-logo-dark-theme.png new file mode 100644 index 000000000..984dd57ab Binary files /dev/null and b/apps/dash-image-segmentation/assets/images/plotly-logo-dark-theme.png differ diff --git a/apps/dash-image-segmentation/assets/images/plotly-logo-light-theme.png b/apps/dash-image-segmentation/assets/images/plotly-logo-light-theme.png new file mode 100644 index 000000000..4920c6e34 Binary files /dev/null and b/apps/dash-image-segmentation/assets/images/plotly-logo-light-theme.png differ diff --git a/apps/dash-image-segmentation/assets/segmentation_img.jpg b/apps/dash-image-segmentation/assets/images/segmentation_img.jpg similarity index 100% rename from apps/dash-image-segmentation/assets/segmentation_img.jpg rename to apps/dash-image-segmentation/assets/images/segmentation_img.jpg diff --git a/apps/dash-image-segmentation/assets/segmentation_img_example_marks.jpg b/apps/dash-image-segmentation/assets/images/segmentation_img_example_marks.jpg similarity index 100% rename from apps/dash-image-segmentation/assets/segmentation_img_example_marks.jpg rename to apps/dash-image-segmentation/assets/images/segmentation_img_example_marks.jpg diff --git a/apps/dash-image-segmentation/constants.py b/apps/dash-image-segmentation/constants.py new file mode 100644 index 000000000..f8f795463 --- /dev/null +++ b/apps/dash-image-segmentation/constants.py @@ -0,0 +1,28 @@ +from joblib import Memory +from utils.trainable_segmentation import multiscale_basic_features +from skimage import io as skio + + +memory = Memory("./joblib_cache", bytes_limit=3000000000, verbose=3) + +compute_features = memory.cache(multiscale_basic_features) + +DEFAULT_STROKE_WIDTH = 3 # gives line width of 2^3 = 8 + +DEFAULT_IMAGE_PATH = "assets/images/segmentation_img.jpg" + +SEG_FEATURE_TYPES = ["intensity", "edges", "texture"] + +# the number of 
different classes for labels
+NUM_LABEL_CLASSES = 5
+DEFAULT_LABEL_CLASS = 0
+class_label_colormap = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2"]
+class_labels = list(range(NUM_LABEL_CLASSES))
+# we can't have fewer colors than classes
+assert NUM_LABEL_CLASSES <= len(class_label_colormap)
+
+# Font and background colors associated with each theme
+text_color = {"dark": "#95969A", "light": "#595959"}
+card_color = {"dark": "#2D3038", "light": "#FFFFFF"}
+
+img = skio.imread(DEFAULT_IMAGE_PATH)
diff --git a/apps/dash-image-segmentation/gitignore b/apps/dash-image-segmentation/gitignore
new file mode 100644
index 000000000..d8e187da3
--- /dev/null
+++ b/apps/dash-image-segmentation/gitignore
@@ -0,0 +1,191 @@
+# .gitignore specifies the files that shouldn't be included
+# in version control and therefore shouldn't be included when
+# deploying an application to Dash Enterprise
+# This is a very exhaustive list!
+# This list was based off of https://github.com/github/gitignore
+
+# Ignore data that is generated during the runtime of an application
+# This folder is used by the "Large Data" sample applications
+runtime_data/
+data/
+
+# Omit SQLite databases that may be produced by dash-snapshots in development
+*.db
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+
+# Jupyter Notebook
+
+.ipynb_checkpoints
+*/.ipynb_checkpoints/*
+
+# IPython
+profile_default/
+ipython_config.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+
+# macOS General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+# Windows thumbnail cache files
+Thumbs.db
+Thumbs.db:encryptable
+ehthumbs.db
+ehthumbs_vista.db
+
+# Dump file
+*.stackdump
+
+# Folder config file
+[Dd]esktop.ini
+
+# Recycle Bin used on file shares
+$RECYCLE.BIN/
+
+# Windows Installer files
+*.cab
+*.msi
+*.msix
+*.msm
+*.msp
+
+# Windows shortcuts
+*.lnk
+
+# History files
+.Rhistory
+.Rapp.history
+
+# Session Data files
+.RData
+
+# User-specific files
+.Ruserdata
+
+# Example code in package build process
+*-Ex.R
+
+# Output files from R CMD check
+/*.Rcheck/
+
+# RStudio files
+.Rproj.user/
+
+# produced vignettes
+vignettes/*.html
+vignettes/*.pdf
+
+# OAuth2 token, see https://github.com/hadley/httr/releases/tag/v0.3
+.httr-oauth
+
+# knitr and R markdown default cache directories
+*_cache/
+/cache/
+
+# Temporary files created by R markdown
+*.utf8.md
+*.knit.md
+
+# R Environment Variables
+.Renviron
+
+# Linux
+*~
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+# VSCode
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+
+# Sublime Text
+# Cache files for Sublime Text
+*.tmlanguage.cache
+*.tmPreferences.cache
+*.stTheme.cache
+
+# Workspace files
are user-specific +*.sublime-workspace + +# Project files should be checked into the repository, unless a significant +# proportion of contributors will probably not be using Sublime Text +# *.sublime-project + +# SFTP configuration file +sftp-config.json + +# Package control specific files +Package Control.last-run +Package Control.ca-list +Package Control.ca-bundle +Package Control.system-ca-bundle +Package Control.cache/ +Package Control.ca-certs/ +Package Control.merged-ca-bundle +Package Control.user-ca-bundle +oscrypto-ca-bundle.crt +bh_unicode_properties.cache + +# Sublime-github package stores a github token in this file +# https://packagecontrol.io/packages/sublime-github +GitHub.sublime-settings \ No newline at end of file diff --git a/apps/dash-image-segmentation/requirements.txt b/apps/dash-image-segmentation/requirements.txt index 5454e9ea2..052ef620a 100644 --- a/apps/dash-image-segmentation/requirements.txt +++ b/apps/dash-image-segmentation/requirements.txt @@ -1,18 +1,14 @@ -dash_core_components==1.10.0 -plotly==4.7.1 -plotly-express==0.4.1 -CairoSVG==2.4.2 +CairoSVG==2.5.2 matplotlib==3.2.1 scipy==1.4.1 -dash==1.12.0 -numpy==1.18.3 scikit_image==0.16.2 -dash_html_components==1.0.3 -dash_table==4.7.0 -Pillow==7.1.2 +Pillow==9.1.1 scikit_learn==0.23.1 scikit-image==0.16.2 -gunicorn==20.0.4 -pandas==1.0.3 joblib==0.15.1 -dash_bootstrap_components==0.10.7 \ No newline at end of file +dash-bootstrap-components==1.0.3 +dash-mantine-components==0.6.0 +# dash_bootstrap_components==0.10.7 +dash==2.4.1 +pandas==1.4.2 +gunicorn==20.1.0 diff --git a/apps/dash-image-segmentation/runtime.txt b/apps/dash-image-segmentation/runtime.txt new file mode 100644 index 000000000..cfa660c42 --- /dev/null +++ b/apps/dash-image-segmentation/runtime.txt @@ -0,0 +1 @@ +python-3.8.0 \ No newline at end of file diff --git a/apps/dash-image-segmentation/utils/components.py b/apps/dash-image-segmentation/utils/components.py new file mode 100644 index 000000000..9f4ed7576 --- /dev/null +++ b/apps/dash-image-segmentation/utils/components.py @@ -0,0 +1,286 @@ +import dash_bootstrap_components as dbc +import dash_mantine_components as dmc +from dash import html, dcc +from utils.figures import make_default_figure +from constants import class_labels, DEFAULT_STROKE_WIDTH, SEG_FEATURE_TYPES +from utils.helper_functions import class_to_color + +# Modal +with open("explanations.md", "r") as f: + howto_md = f.read() + +modal_overlay = dbc.Modal( + [ + dbc.ModalBody(html.Div([dcc.Markdown(howto_md)], id="howto-md")), + dbc.ModalFooter(dbc.Button("Close", id="howto-close", className="howto-bn")), + ], + id="modal", + size="lg", +) + +button_howto = dbc.Button( + "Learn more", + id="howto-open", + outline=False, + color="info", + # Turn off lowercase transformation for class .button in stylesheet + style={"textTransform": "none"}, + size="md", +) + +button_github = dbc.Button( + "Github Code", + outline=False, + color="primary", + href="https://github.com/plotly/dash-sample-apps/tree/master/apps/dash-image-segmentation", + id="gh-link", + style={"text-transform": "none"}, + size="md", +) + +header_items = dmc.Group( + position="apart", + children=[ + dmc.Image( + src="assets/images/plotly-logo-light-theme.png", width=200, height=40 + ), + dmc.Text( + "Dash Image Segmentation", + color="gray", + size="xl", + weight=600, + transform="capitalize", + ), + dbc.NavItem(button_howto), + dbc.NavItem(button_github), + ], +) + +# Description +description = dbc.Col( + [ + dbc.Card( + id="description-card", + children=[ + 
dbc.CardHeader("Explanation"), + dbc.CardBody( + [ + dbc.Row( + [ + dbc.Col( + [ + html.Img( + src="assets/images/segmentation_img_example_marks.jpg", + width="200px", + ) + ], + md="auto", + ), + dbc.Col( + html.P( + "This is an example of interactive machine learning for image classification. " + "To train the classifier, draw some marks on the picture using different colors for " + 'different parts, like in the example image. Then enable "Show segmentation" to see the ' + "classes a Random Forest Classifier gave to regions of the image, based on the marks you " + "used as a guide. You may add more marks to clarify parts of the image where the " + "classifier was not successful and the classification will update." + ), + md=True, + ), + ] + ), + ] + ), + ], + ) + ], + md=12, +) + +# Image Segmentation +segmentation = [ + dbc.Card( + id="segmentation-card", + children=[ + dbc.CardHeader("Viewer"), + dbc.CardBody( + [ + # Wrap dcc.Loading in a div to force transparency when loading + html.Div( + id="transparent-loader-wrapper", + children=[ + dcc.Loading( + id="segmentations-loading", + type="cube", + children=[ + # Graph + dcc.Graph( + id="graph", + figure=make_default_figure(), + config={ + "modeBarButtonsToAdd": [ + "drawrect", + "drawopenpath", + "eraseshape", + ] + }, + ), + ], + ) + ], + ), + ] + ), + dbc.CardFooter( + [ + # Download links + html.A( + id="download", + download="classifier.json", + ), + html.Div( + children=[ + dbc.ButtonGroup( + [ + dbc.Button( + "Download classified image", + id="download-image-button", + outline=True, + ), + dbc.Button( + "Download classifier", + id="download-button", + outline=True, + ), + ], + size="lg", + style={"width": "100%"}, + ), + ], + ), + html.A( + id="download-image", + download="classified-image.png", + ), + ] + ), + ], + ) +] + +# sidebar +sidebar = [ + dbc.Card( + id="sidebar-card", + children=[ + dbc.CardHeader("Tools"), + dbc.CardBody( + [ + html.H6("Label class", className="card-title"), + # Label class chosen with buttons + html.Div( + id="label-class-buttons", + children=[ + dbc.Button( + "%2d" % (n,), + id={"type": "label-class-button", "index": n}, + style={"background-color": class_to_color(c)}, + ) + for n, c in enumerate(class_labels) + ], + ), + html.Hr(), + dbc.Form( + [ + dbc.Row( + [ + dbc.Label( + "Width of annotation paintbrush", + html_for="stroke-width", + ), + # Slider for specifying stroke width + dcc.Slider( + id="stroke-width", + min=0, + max=6, + step=1, + value=DEFAULT_STROKE_WIDTH, + ), + ] + ), + dbc.Row( + [ + html.H6( + id="stroke-width-display", + className="card-title", + ), + dbc.Label( + "Blurring parameter", + html_for="sigma-range-slider", + ), + dcc.RangeSlider( + id="sigma-range-slider", + min=0.01, + max=20, + step=0.01, + value=[0.5, 16], + ), + ] + ), + dbc.Row( + [ + dbc.Label( + "Select features", + html_for="segmentation-features", + ), + dcc.Checklist( + id="segmentation-features", + options=[ + {"label": l.capitalize(), "value": l} + for l in SEG_FEATURE_TYPES + ], + value=["intensity", "edges"], + ), + ] + ), + # Indicate showing most recently computed segmentation + # dbc.Button( + # name="Show Segmentation", + # size="lg", + # id="show-segmentation", + # value="Show Segmentation", + # type="submit", + # ), + dcc.Checklist( + id="show-segmentation", + options=[ + { + "label": "Show segmentation", + "value": "Show segmentation", + } + ], + value=[], + ), + ] + ), + ] + ), + ], + ), +] + +meta = [ + html.Div( + id="no-display", + children=[ + # Store for user created masks + # data is a 
list of dicts describing shapes + dcc.Store(id="masks", data={"shapes": []}), + dcc.Store(id="classifier-store", data={}), + dcc.Store(id="classified-image-store", data=""), + dcc.Store(id="features_hash", data=""), + ], + ), + html.Div(id="download-dummy"), + html.Div(id="download-image-dummy"), +] diff --git a/apps/dash-image-segmentation/utils/figures.py b/apps/dash-image-segmentation/utils/figures.py new file mode 100644 index 000000000..8ae69d9cf --- /dev/null +++ b/apps/dash-image-segmentation/utils/figures.py @@ -0,0 +1,128 @@ +import utils.plot_common as plot_common +from constants import ( + DEFAULT_IMAGE_PATH, + DEFAULT_LABEL_CLASS, + DEFAULT_STROKE_WIDTH, +) +import dash +import PIL.Image +from utils.helper_functions import class_to_color, show_segmentation +from time import time +from constants import compute_features, img, DEFAULT_LABEL_CLASS +from utils.shapes_to_segmentations import ( + compute_segmentations, + blend_image_and_classified_regions_pil, +) + +import io + + +def make_default_figure( + images=[DEFAULT_IMAGE_PATH], + stroke_color=class_to_color(DEFAULT_LABEL_CLASS), + stroke_width=DEFAULT_STROKE_WIDTH, + shapes=[], +): + fig = plot_common.dummy_fig() + plot_common.add_layout_images_to_fig(fig, images) + fig.update_layout( + { + "dragmode": "drawopenpath", + "shapes": shapes, + "newshape.line.color": stroke_color, + "newshape.line.width": stroke_width, + "margin": dict(l=0, r=0, b=0, t=0, pad=4), + } + ) + return fig + + +def annotation_react( + graph_relayoutData, + any_label_class_button_value, + stroke_width_value, + show_segmentation_value, + download_button_n_clicks, + download_image_button_n_clicks, + segmentation_features_value, + sigma_range_slider_value, + masks_data, +): + classified_image_store_data = dash.no_update + classifier_store_data = dash.no_update + cbcontext = [p["prop_id"] for p in dash.callback_context.triggered][0] + if cbcontext in ["segmentation-features.value", "sigma-range-slider.value"] or ( + ("Show segmentation" in show_segmentation_value) + and (len(masks_data["shapes"]) > 0) + ): + segmentation_features_dict = { + "intensity": False, + "edges": False, + "texture": False, + } + for feat in segmentation_features_value: + segmentation_features_dict[feat] = True + t1 = time() + features = compute_features( + img, + **segmentation_features_dict, + sigma_min=sigma_range_slider_value[0], + sigma_max=sigma_range_slider_value[1], + ) + t2 = time() + print(t2 - t1) + if cbcontext == "graph.relayoutData": + if "shapes" in graph_relayoutData.keys(): + masks_data["shapes"] = graph_relayoutData["shapes"] + else: + return dash.no_update + stroke_width = int(round(2 ** (stroke_width_value))) + # find label class value by finding button with the most recent click + if any_label_class_button_value is None: + label_class_value = DEFAULT_LABEL_CLASS + else: + label_class_value = max( + enumerate(any_label_class_button_value), + key=lambda t: 0 if t[1] is None else t[1], + )[0] + + fig = make_default_figure( + stroke_color=class_to_color(label_class_value), + stroke_width=stroke_width, + shapes=masks_data["shapes"], + ) + # We want the segmentation to be computed + if ("Show segmentation" in show_segmentation_value) and ( + len(masks_data["shapes"]) > 0 + ): + segimgpng = None + try: + feature_opts = dict(segmentation_features_dict=segmentation_features_dict) + feature_opts["sigma_min"] = sigma_range_slider_value[0] + feature_opts["sigma_max"] = sigma_range_slider_value[1] + segimgpng, clf = show_segmentation( + DEFAULT_IMAGE_PATH, masks_data["shapes"], 
features, feature_opts + ) + if cbcontext == "download-button.n_clicks": + classifier_store_data = clf + if cbcontext == "download-image-button.n_clicks": + classified_image_store_data = plot_common.pil_image_to_uri( + blend_image_and_classified_regions_pil( + PIL.Image.open(DEFAULT_IMAGE_PATH), segimgpng + ) + ) + except ValueError: + # if segmentation fails, draw nothing + pass + images_to_draw = [] + if segimgpng is not None: + images_to_draw = [segimgpng] + fig = plot_common.add_layout_images_to_fig(fig, images_to_draw) + fig.update_layout(uirevision="segmentation") + return ( + fig, + masks_data, + "Current paintbrush width: %d" % (stroke_width,), + classifier_store_data, + classified_image_store_data, + ) diff --git a/apps/dash-image-segmentation/utils/helper_functions.py b/apps/dash-image-segmentation/utils/helper_functions.py new file mode 100644 index 000000000..167c93d3b --- /dev/null +++ b/apps/dash-image-segmentation/utils/helper_functions.py @@ -0,0 +1,85 @@ +import base64 +import PIL.Image +import json +import utils.plot_common as plot_common +import io +import pickle + + +from constants import class_label_colormap + +from utils.shapes_to_segmentations import ( + compute_segmentations, + blend_image_and_classified_regions_pil, +) + + +def class_to_color(n): + return class_label_colormap[n] + + +def color_to_class(c): + return class_label_colormap.index(c) + + +def shapes_to_key(shapes): + return json.dumps(shapes) + + +def store_shapes_seg_pair(d, key, seg, remove_old=True): + """ + Stores shapes and segmentation pair in dict d + seg is a PIL.Image object + if remove_old True, deletes all the old keys and values. + """ + bytes_to_encode = io.BytesIO() + seg.save(bytes_to_encode, format="png") + bytes_to_encode.seek(0) + data = base64.b64encode(bytes_to_encode.read()).decode() + if remove_old: + return {key: data} + d[key] = data + return d + + +def look_up_seg(d, key): + """Returns a PIL.Image object""" + data = d[key] + img_bytes = base64.b64decode(data) + img = PIL.Image.open(io.BytesIO(img_bytes)) + return img + + +# Converts image classifier to a JSON compatible encoding and creates a +# dictionary that can be downloaded +# see use_ml_image_segmentation_classifier.py +def save_img_classifier(clf, label_to_colors_args, segmenter_args): + clfbytes = io.BytesIO() + pickle.dump(clf, clfbytes) + clfb64 = base64.b64encode(clfbytes.getvalue()).decode() + return { + "classifier": clfb64, + "segmenter_args": segmenter_args, + "label_to_colors_args": label_to_colors_args, + } + + +def show_segmentation(image_path, mask_shapes, features, segmenter_args): + """adds an image showing segmentations to a figure's layout""" + # add 1 because classifier takes 0 to mean no mask + shape_layers = [color_to_class(shape["line"]["color"]) + 1 for shape in mask_shapes] + label_to_colors_args = { + "colormap": class_label_colormap, + "color_class_offset": -1, + } + segimg, _, clf = compute_segmentations( + mask_shapes, + img_path=image_path, + shape_layers=shape_layers, + label_to_colors_args=label_to_colors_args, + features=features, + ) + # get the classifier that we can later store in the Store + classifier = save_img_classifier(clf, label_to_colors_args, segmenter_args) + segimgpng = plot_common.img_array_to_pil_image(segimg) + return (segimgpng, classifier) diff --git a/apps/dash-image-segmentation/plot_common.py b/apps/dash-image-segmentation/utils/plot_common.py similarity index 100% rename from apps/dash-image-segmentation/plot_common.py rename to 
apps/dash-image-segmentation/utils/plot_common.py diff --git a/apps/dash-image-segmentation/shape_utils.py b/apps/dash-image-segmentation/utils/shape_utils.py similarity index 100% rename from apps/dash-image-segmentation/shape_utils.py rename to apps/dash-image-segmentation/utils/shape_utils.py diff --git a/apps/dash-image-segmentation/shapes_to_segmentations.py b/apps/dash-image-segmentation/utils/shapes_to_segmentations.py similarity index 97% rename from apps/dash-image-segmentation/shapes_to_segmentations.py rename to apps/dash-image-segmentation/utils/shapes_to_segmentations.py index a0165a5b6..5cd708841 100644 --- a/apps/dash-image-segmentation/shapes_to_segmentations.py +++ b/apps/dash-image-segmentation/utils/shapes_to_segmentations.py @@ -4,8 +4,8 @@ import skimage.util import skimage.io import skimage.color -import shape_utils -from trainable_segmentation import fit_segmenter +import utils.shape_utils as shape_utils +from utils.trainable_segmentation import fit_segmenter import plotly.express as px from sklearn.ensemble import RandomForestClassifier from time import time diff --git a/apps/dash-image-segmentation/trainable_segmentation.py b/apps/dash-image-segmentation/utils/trainable_segmentation.py similarity index 100% rename from apps/dash-image-segmentation/trainable_segmentation.py rename to apps/dash-image-segmentation/utils/trainable_segmentation.py diff --git a/apps/dash-image-segmentation/use_ml_image_segmentation_classifier.py b/apps/dash-image-segmentation/utils/use_ml_image_segmentation_classifier.py similarity index 92% rename from apps/dash-image-segmentation/use_ml_image_segmentation_classifier.py rename to apps/dash-image-segmentation/utils/use_ml_image_segmentation_classifier.py index b023e113f..f3e4faf5b 100644 --- a/apps/dash-image-segmentation/use_ml_image_segmentation_classifier.py +++ b/apps/dash-image-segmentation/utils/use_ml_image_segmentation_classifier.py @@ -14,9 +14,9 @@ """ import os -import plot_common -import shapes_to_segmentations -from trainable_segmentation import multiscale_basic_features, predict_segmenter +import utils.plot_common as plot_common +import utils.shapes_to_segmentations as shapes_to_segmentations +from utils.trainable_segmentation import multiscale_basic_features, predict_segmenter import pickle import base64 import io
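
A few notes on the techniques this patch leans on, with minimal sketches rather than the app's exact code. The core change in app.py is the Dash 2.x migration: the deprecated `dash_html_components` / `dash_core_components` packages become `dash.html` / `dash.dcc`, dependency lists are passed as flat arguments instead of wrapped in lists, and callbacks can be registered with the module-level `@callback` decorator, so registration no longer needs a reference to the `app` object. A toy illustration of the new style (hypothetical ids, not the app's real layout):

```python
# Dash 2.x style used throughout the refactor: built-in html/dcc modules,
# flat dependency arguments, and the module-level @callback decorator.
from dash import Dash, html, dcc, Input, Output, callback

app = Dash(__name__)
app.layout = html.Div([dcc.Input(id="name", value="world"), html.Div(id="greeting")])


@callback(Output("greeting", "children"), Input("name", "value"))
def greet(name):
    return f"Hello, {name}!"


if __name__ == "__main__":
    app.run_server(debug=True)
```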
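The five label-class buttons share one pattern-matching dependency, `Input({"type": "label-class-button", "index": dash.dependencies.ALL}, "n_clicks_timestamp")`, so `annotation_react` receives a list of click timestamps, one per button and `None` until a button is first clicked, and treats the newest timestamp as the active class. That selection logic, lifted out of `utils/figures.py` as a standalone sketch:

```python
# Selection of the active label class, as done in annotation_react: the
# newest n_clicks_timestamp wins; None means "never clicked".
DEFAULT_LABEL_CLASS = 0  # mirrors constants.py


def latest_clicked(timestamps):
    """Index of the most recently clicked label button, else the default."""
    if timestamps is None:
        return DEFAULT_LABEL_CLASS
    return max(enumerate(timestamps), key=lambda t: 0 if t[1] is None else t[1])[0]


assert latest_clicked(None) == DEFAULT_LABEL_CLASS
assert latest_clicked([None, 1650000000, None, 1650000999, None]) == 3
```

The same callback reads `dash.callback_context.triggered[0]["prop_id"]` to tell which of its eight inputs fired, and derives the brush size as `int(round(2 ** stroke_width_value))`, which is why `DEFAULT_STROKE_WIDTH = 3` is commented as a line width of 8.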
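`make_default_figure` turns the graph into an annotation canvas purely through layout settings: `dragmode="drawopenpath"` plus `newshape.line.*` styling for the strokes the user draws. A rough equivalent applied to a plain `px.imshow` figure (the real app places the image through its `plot_common` helpers instead):

```python
# Sketch of make_default_figure's layout settings on a stand-in image.
import numpy as np
import plotly.express as px

image = np.random.rand(64, 64)  # placeholder for segmentation_img.jpg
fig = px.imshow(image)
fig.update_layout(
    dragmode="drawopenpath",        # free-hand strokes by default
    newshape_line_color="#E69F00",  # class_to_color(DEFAULT_LABEL_CLASS)
    newshape_line_width=8,          # int(round(2 ** DEFAULT_STROKE_WIDTH))
    margin=dict(l=0, r=0, b=0, t=0, pad=4),
)
# dcc.Graph(config={"modeBarButtonsToAdd": ["drawrect", "drawopenpath",
# "eraseshape"]}) then exposes the rectangle, path, and eraser tools.
```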
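`constants.py` keeps the per-pixel feature extraction affordable across callback invocations by wrapping `multiscale_basic_features` in a joblib disk cache: a repeated call with the same image and sigma range is read back from `./joblib_cache` instead of recomputed. The pattern, with a stand-in for the expensive function:

```python
# The disk-cache pattern from constants.py; expensive_features stands in
# for multiscale_basic_features.
import numpy as np
from joblib import Memory

memory = Memory("./joblib_cache", bytes_limit=3_000_000_000, verbose=0)


@memory.cache  # results keyed on arguments, persisted under ./joblib_cache
def expensive_features(image, sigma_min=0.5, sigma_max=16):
    # placeholder for a costly per-pixel feature computation
    return np.stack([image * s for s in np.linspace(sigma_min, sigma_max, 4)])


image = np.random.rand(64, 64)
first = expensive_features(image)   # computed, then written to disk
second = expensive_features(image)  # served from the cache
assert np.allclose(first, second)
```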
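`dcc.Store` only holds JSON-serializable data, so `store_shapes_seg_pair` / `look_up_seg` in `utils/helper_functions.py` round-trip the segmentation overlay as a base64-encoded PNG string. The encode/decode pair, reduced to its essentials:

```python
# Round trip used by store_shapes_seg_pair / look_up_seg.
import base64
import io

import PIL.Image


def pil_to_b64(image):
    buffer = io.BytesIO()
    image.save(buffer, format="png")
    return base64.b64encode(buffer.getvalue()).decode()


def b64_to_pil(data):
    return PIL.Image.open(io.BytesIO(base64.b64decode(data)))


overlay = PIL.Image.new("RGB", (8, 8), (230, 159, 0))
assert b64_to_pil(pil_to_b64(overlay)).size == (8, 8)
```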
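`save_img_classifier` makes the trained model downloadable by pickling it and base64-encoding the bytes into a JSON-compatible dict, which `utils/use_ml_image_segmentation_classifier.py` later reverses. A self-contained round trip on a tiny toy training set:

```python
# Serialization used by save_img_classifier, and its inverse.
import base64
import io
import json
import pickle

from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(n_estimators=5).fit([[0], [1]], [0, 1])

buffer = io.BytesIO()
pickle.dump(clf, buffer)
payload = json.dumps({"classifier": base64.b64encode(buffer.getvalue()).decode()})

restored = pickle.loads(base64.b64decode(json.loads(payload)["classifier"]))
assert restored.predict([[1]])[0] == 1
```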
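The download flow stays clientside: one callback copies the stored data URI into the `href` of the hidden `<a id="download-image">`, and a second fires when that `href` changes to simulate a click. The JavaScript bodies are elided in this hunk, so the sketch below only shows the registration shape in the new flat-dependency style, with a guessed function body:

```python
# Registration shape of the second download callback; the JS body is a
# guess, since the patch elides the original function text.
from dash import Input, Output, clientside_callback

clientside_callback(
    """
    function(href) {
        // fire a synthetic click so the browser starts the download
        const link = document.getElementById("download-image");
        if (link && href) { link.click(); }
        return "";
    }
    """,
    Output("download-image-dummy", "children"),
    Input("download-image", "href"),
)
```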