Skip to content

Commit 29debc1

Browse files
Add support for mask overlay (#18)
* add support for overlay (wip) * Some refactoring to better accomodate multiple traces in a figure * add example * Enable setting overlay colormap. Update example to use multiple thresholds. * rename trigger -> refresh * Refactoring to store overlay data at the client * remove refresh signal * deploy version with overlay * mouse up for threshold slider Co-authored-by: Emmanuelle Gouillart <[email protected]>
1 parent ffcb262 commit 29debc1

File tree

4 files changed

+243
-56
lines changed

4 files changed

+243
-56
lines changed

Procfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
web: gunicorn test_deploy.app:server
1+
web: gunicorn examples.threshold_overlay:server

dash_slicer/slicer.py

Lines changed: 177 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from plotly.graph_objects import Figure, Image, Scatter
2+
from plotly.graph_objects import Figure
33
from dash import Dash
44
from dash.dependencies import Input, Output, State, ALL
55
from dash_core_components import Graph, Slider, Store
@@ -51,7 +51,7 @@ def __init__(
5151
origin=None,
5252
axis=0,
5353
reverse_y=True,
54-
scene_id=None
54+
scene_id=None,
5555
):
5656

5757
if not isinstance(app, Dash):
@@ -112,6 +112,11 @@ def axis(self):
112112
"""The axis at which the slicer is slicing."""
113113
return self._axis
114114

115+
@property
116+
def nslices(self):
117+
"""The number of slices for this slicer."""
118+
return self._volume.shape[self._axis]
119+
115120
@property
116121
def graph(self):
117122
"""The dcc.Graph for this slicer."""
@@ -124,11 +129,69 @@ def slider(self):
124129

125130
@property
126131
def stores(self):
127-
"""A list of dcc.Stores that the slicer needs to work. These must
128-
be added to the app layout.
132+
"""A list of dcc.Store objects that the slicer needs to work.
133+
These must be added to the app layout.
129134
"""
130135
return self._stores
131136

137+
@property
138+
def overlay_data(self):
139+
"""A dcc.Store containing the overlay data. The form of this
140+
data is considered an implementation detail; users are expected to use
141+
``create_overlay_data`` to create it.
142+
"""
143+
return self._overlay_data
144+
145+
def create_overlay_data(self, mask, color=(0, 255, 255, 100)):
146+
"""Given a 3D mask array and an index, create an object that
147+
can be used as output for ``slicer.overlay_data``.
148+
"""
149+
# Check the mask
150+
if mask.dtype not in (np.bool, np.uint8):
151+
raise ValueError(f"Mask must have bool or uint8 dtype, not {mask.dtype}.")
152+
if mask.shape != self._volume.shape:
153+
raise ValueError(
154+
f"Overlay must has shape {mask.shape}, but expected {self._volume.shape}"
155+
)
156+
mask = mask.astype(np.uint8, copy=False) # need int to index
157+
158+
# Create a colormap (list) from the given color(s)
159+
# todo: also support hex colors and css color names
160+
color = np.array(color, np.uint8)
161+
if color.ndim == 1:
162+
if color.shape[0] != 4:
163+
raise ValueError("Overlay color must be 4 ints (0..255).")
164+
colormap = [(0, 0, 0, 0), tuple(color)]
165+
elif color.ndim == 2:
166+
if color.shape[1] != 4:
167+
raise ValueError("Overlay colors must be 4 ints (0..255).")
168+
colormap = [tuple(x) for x in color]
169+
else:
170+
raise ValueError(
171+
"Overlay color must be a single color or a list of colors."
172+
)
173+
174+
# Produce slices (base64 png strings)
175+
overlay_slices = []
176+
for index in range(self.nslices):
177+
# Sample the slice
178+
indices = [slice(None), slice(None), slice(None)]
179+
indices[self._axis] = index
180+
im = mask[tuple(indices)]
181+
max_mask = im.max()
182+
if max_mask == 0:
183+
# If the mask is all zeros, we can simply not draw it
184+
overlay_slices.append(None)
185+
else:
186+
# Turn into rgba
187+
while len(colormap) <= max_mask:
188+
colormap.append(colormap[-1])
189+
colormap_arr = np.array(colormap)
190+
rgba = colormap_arr[im]
191+
overlay_slices.append(img_array_to_uri(rgba))
192+
193+
return overlay_slices
194+
132195
def _subid(self, name, use_dict=False):
133196
"""Given a name, get the full id including the context id prefix."""
134197
if use_dict:
@@ -163,15 +226,8 @@ def _create_dash_components(self):
163226
]
164227
info["lowres_size"] = thumbnail_size
165228

166-
# Create traces
167-
# todo: can add "%{z[0]}", but that would be the scaled value ...
168-
image_trace = Image(
169-
source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})<extra></extra>"
170-
)
171-
scatter_trace = Scatter(x=[], y=[]) # placeholder
172-
173229
# Create the figure object - can be accessed by user via slicer.graph.figure
174-
self._fig = fig = Figure(data=[image_trace, scatter_trace])
230+
self._fig = fig = Figure(data=[])
175231
fig.update_layout(
176232
template=None,
177233
margin=dict(l=0, r=0, b=0, t=0, pad=4),
@@ -212,15 +268,19 @@ def _create_dash_components(self):
212268
self._position = Store(id=self._subid("position", True), data=0)
213269
self._requested_index = Store(id=self._subid("req-index"), data=0)
214270
self._request_data = Store(id=self._subid("req-data"), data="")
215-
self._lowres_data = Store(id=self._subid("lowres-data"), data=thumbnails)
216-
self._indicators = Store(id=self._subid("indicators"), data=[])
271+
self._lowres_data = Store(id=self._subid("lowres"), data=thumbnails)
272+
self._overlay_data = Store(id=self._subid("overlay"), data=[])
273+
self._img_traces = Store(id=self._subid("img-traces"), data=[])
274+
self._indicator_traces = Store(id=self._subid("indicator-traces"), data=[])
217275
self._stores = [
218276
self._info,
219277
self._position,
220278
self._requested_index,
221279
self._request_data,
222280
self._lowres_data,
223-
self._indicators,
281+
self._overlay_data,
282+
self._img_traces,
283+
self._indicator_traces,
224284
]
225285

226286
def _create_server_callbacks(self):
@@ -232,13 +292,16 @@ def _create_server_callbacks(self):
232292
[Input(self._requested_index.id, "data")],
233293
)
234294
def upload_requested_slice(slice_index):
235-
slice = self._slice(slice_index)
236-
return [slice_index, img_array_to_uri(slice)]
295+
slice = img_array_to_uri(self._slice(slice_index))
296+
return {"index": slice_index, "slice": slice}
237297

238298
def _create_client_callbacks(self):
239299
"""Create the callbacks that run client-side."""
240300
app = self._app
241301

302+
# ----------------------------------------------------------------------
303+
# Callback to update position (in scene coordinates) from the index.
304+
242305
app.clientside_callback(
243306
"""
244307
function update_position(index, info) {
@@ -250,15 +313,25 @@ def _create_client_callbacks(self):
250313
[State(self._info.id, "data")],
251314
)
252315

316+
# ----------------------------------------------------------------------
317+
# Callback to request new slices.
318+
# Note: this callback cannot be merged with the one below, because
319+
# it would create a circular dependency.
320+
253321
app.clientside_callback(
254322
"""
255-
function handle_slice_index(index) {
323+
function update_request(index) {
324+
325+
// Clear the cache?
256326
if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; }
257327
let slice_cache = window.slicecache_for_{{ID}};
328+
329+
// Request a new slice (or not)
330+
let request_index = index;
258331
if (slice_cache[index]) {
259332
return window.dash_clientside.no_update;
260333
} else {
261-
console.log('requesting slice ' + index)
334+
console.log('request slice ' + index);
262335
return index;
263336
}
264337
}
@@ -269,60 +342,80 @@ def _create_client_callbacks(self):
269342
[Input(self.slider.id, "value")],
270343
)
271344

345+
# ----------------------------------------------------------------------
346+
# Callback that creates a list of image traces (slice and overlay).
347+
272348
app.clientside_callback(
273349
"""
274-
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, info) {
275-
let new_index = index_and_data[0];
276-
let new_data = index_and_data[1];
277-
// Store data in cache
350+
function update_image_traces(index, req_data, overlays, lowres, info, current_traces) {
351+
352+
// Add data to the cache if the data is indeed new
278353
if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; }
279354
let slice_cache = window.slicecache_for_{{ID}};
280-
slice_cache[new_index] = new_data;
281-
// Get the data we need *now*
282-
let data = slice_cache[index];
283-
let x0 = info.origin[0], y0 = info.origin[1];
284-
let dx = info.spacing[0], dy = info.spacing[1];
285-
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
286-
// Maybe we do not need an update
287-
if (!data) {
288-
data = lowres[index];
355+
for (let trigger of dash_clientside.callback_context.triggered) {
356+
if (trigger.prop_id.indexOf('req-data') >= 0) {
357+
slice_cache[req_data.index] = req_data;
358+
break;
359+
}
360+
}
361+
362+
// Prepare traces
363+
let slice_trace = {
364+
type: 'image',
365+
x0: info.origin[0],
366+
y0: info.origin[1],
367+
dx: info.spacing[0],
368+
dy: info.spacing[1],
369+
hovertemplate: '(%{x:.2f}, %{y:.2f})<extra></extra>'
370+
};
371+
let overlay_trace = {...slice_trace};
372+
overlay_trace.hoverinfo = 'skip';
373+
overlay_trace.source = overlays[index] || '';
374+
overlay_trace.hovertemplate = '';
375+
let new_traces = [slice_trace, overlay_trace];
376+
377+
// Depending on the state of the cache, use full data, or use lowres and request slice
378+
if (slice_cache[index]) {
379+
let cached = slice_cache[index];
380+
slice_trace.source = cached.slice;
381+
} else {
382+
slice_trace.source = lowres[index];
289383
// Scale the image to take the exact same space as the full-res
290384
// version. It's not correct, but it looks better ...
291-
dx *= info.size[0] / info.lowres_size[0];
292-
dy *= info.size[1] / info.lowres_size[1];
293-
x0 += 0.5 * dx - 0.5 * info.spacing[0];
294-
y0 += 0.5 * dy - 0.5 * info.spacing[1];
385+
slice_trace.dx *= info.size[0] / info.lowres_size[0];
386+
slice_trace.dy *= info.size[1] / info.lowres_size[1];
387+
slice_trace.x0 += 0.5 * slice_trace.dx - 0.5 * info.spacing[0];
388+
slice_trace.y0 += 0.5 * slice_trace.dy - 0.5 * info.spacing[1];
295389
}
296-
if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) {
297-
return window.dash_clientside.no_update;
390+
391+
// Has the image data even changed?
392+
if (!current_traces.length) { current_traces = [{source:''}, {source:''}]; }
393+
if (new_traces[0].source == current_traces[0].source &&
394+
new_traces[1].source == current_traces[1].source)
395+
{
396+
new_traces = window.dash_clientside.no_update;
298397
}
299-
// Otherwise, perform update
300-
console.log("updating figure");
301-
let figure = {...ori_figure};
302-
figure.data[0].source = data;
303-
figure.data[0].x0 = x0;
304-
figure.data[0].y0 = y0;
305-
figure.data[0].dx = dx;
306-
figure.data[0].dy = dy;
307-
figure.data[1] = indicators;
308-
return figure;
398+
return new_traces;
309399
}
310400
""".replace(
311401
"{{ID}}", self._context_id
312402
),
313-
Output(self.graph.id, "figure"),
403+
Output(self._img_traces.id, "data"),
314404
[
315405
Input(self.slider.id, "value"),
316406
Input(self._request_data.id, "data"),
317-
Input(self._indicators.id, "data"),
407+
Input(self._overlay_data.id, "data"),
318408
],
319409
[
320-
State(self.graph.id, "figure"),
321410
State(self._lowres_data.id, "data"),
322411
State(self._info.id, "data"),
412+
State(self._img_traces.id, "data"),
323413
],
324414
)
325415

416+
# ----------------------------------------------------------------------
417+
# Callback to create scatter traces from the positions of other slicers.
418+
326419
# Select the *other* axii
327420
axii = [0, 1, 2]
328421
axii.pop(self._axis)
@@ -349,18 +442,18 @@ def _create_client_callbacks(self):
349442
x.push(...[pos, pos, pos, pos, pos, pos]);
350443
y.push(...[y0 - d, y0, null, y1, y1 + d, null]);
351444
}
352-
return {
445+
return [{
353446
type: 'scatter',
354447
mode: 'lines',
355448
line: {color: '#ff00aa'},
356449
x: x,
357450
y: y,
358451
hoverinfo: 'skip',
359452
version: version
360-
};
453+
}];
361454
}
362455
""",
363-
Output(self._indicators.id, "data"),
456+
Output(self._indicator_traces.id, "data"),
364457
[
365458
Input(
366459
{
@@ -375,6 +468,35 @@ def _create_client_callbacks(self):
375468
],
376469
[
377470
State(self._info.id, "data"),
378-
State(self._indicators.id, "data"),
471+
State(self._indicator_traces.id, "data"),
472+
],
473+
)
474+
475+
# ----------------------------------------------------------------------
476+
# Callback that composes a figure from multiple trace sources.
477+
478+
app.clientside_callback(
479+
"""
480+
function update_figure(img_traces, indicators, ori_figure) {
481+
482+
// Collect traces
483+
let traces = [];
484+
for (let trace of img_traces) { traces.push(trace); }
485+
for (let trace of indicators) { traces.push(trace); }
486+
487+
// Update figure
488+
console.log("updating figure");
489+
let figure = {...ori_figure};
490+
figure.data = traces;
491+
return figure;
492+
}
493+
""",
494+
Output(self.graph.id, "figure"),
495+
[
496+
Input(self._img_traces.id, "data"),
497+
Input(self._indicator_traces.id, "data"),
498+
],
499+
[
500+
State(self.graph.id, "figure"),
379501
],
380502
)

examples/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)