1
1
import numpy as np
2
- from plotly .graph_objects import Figure , Image , Scatter
2
+ from plotly .graph_objects import Figure
3
3
from dash import Dash
4
4
from dash .dependencies import Input , Output , State , ALL
5
5
from dash_core_components import Graph , Slider , Store
@@ -51,7 +51,7 @@ def __init__(
51
51
origin = None ,
52
52
axis = 0 ,
53
53
reverse_y = True ,
54
- scene_id = None
54
+ scene_id = None ,
55
55
):
56
56
57
57
if not isinstance (app , Dash ):
@@ -112,6 +112,11 @@ def axis(self):
112
112
"""The axis at which the slicer is slicing."""
113
113
return self ._axis
114
114
115
+ @property
116
+ def nslices (self ):
117
+ """The number of slices for this slicer."""
118
+ return self ._volume .shape [self ._axis ]
119
+
115
120
@property
116
121
def graph (self ):
117
122
"""The dcc.Graph for this slicer."""
@@ -124,11 +129,69 @@ def slider(self):
124
129
125
130
@property
126
131
def stores (self ):
127
- """A list of dcc.Stores that the slicer needs to work. These must
128
- be added to the app layout.
132
+ """A list of dcc.Store objects that the slicer needs to work.
133
+ These must be added to the app layout.
129
134
"""
130
135
return self ._stores
131
136
137
+ @property
138
+ def overlay_data (self ):
139
+ """A dcc.Store containing the overlay data. The form of this
140
+ data is considered an implementation detail; users are expected to use
141
+ ``create_overlay_data`` to create it.
142
+ """
143
+ return self ._overlay_data
144
+
145
+ def create_overlay_data (self , mask , color = (0 , 255 , 255 , 100 )):
146
+ """Given a 3D mask array and an index, create an object that
147
+ can be used as output for ``slicer.overlay_data``.
148
+ """
149
+ # Check the mask
150
+ if mask .dtype not in (np .bool , np .uint8 ):
151
+ raise ValueError (f"Mask must have bool or uint8 dtype, not { mask .dtype } ." )
152
+ if mask .shape != self ._volume .shape :
153
+ raise ValueError (
154
+ f"Overlay must has shape { mask .shape } , but expected { self ._volume .shape } "
155
+ )
156
+ mask = mask .astype (np .uint8 , copy = False ) # need int to index
157
+
158
+ # Create a colormap (list) from the given color(s)
159
+ # todo: also support hex colors and css color names
160
+ color = np .array (color , np .uint8 )
161
+ if color .ndim == 1 :
162
+ if color .shape [0 ] != 4 :
163
+ raise ValueError ("Overlay color must be 4 ints (0..255)." )
164
+ colormap = [(0 , 0 , 0 , 0 ), tuple (color )]
165
+ elif color .ndim == 2 :
166
+ if color .shape [1 ] != 4 :
167
+ raise ValueError ("Overlay colors must be 4 ints (0..255)." )
168
+ colormap = [tuple (x ) for x in color ]
169
+ else :
170
+ raise ValueError (
171
+ "Overlay color must be a single color or a list of colors."
172
+ )
173
+
174
+ # Produce slices (base64 png strings)
175
+ overlay_slices = []
176
+ for index in range (self .nslices ):
177
+ # Sample the slice
178
+ indices = [slice (None ), slice (None ), slice (None )]
179
+ indices [self ._axis ] = index
180
+ im = mask [tuple (indices )]
181
+ max_mask = im .max ()
182
+ if max_mask == 0 :
183
+ # If the mask is all zeros, we can simply not draw it
184
+ overlay_slices .append (None )
185
+ else :
186
+ # Turn into rgba
187
+ while len (colormap ) <= max_mask :
188
+ colormap .append (colormap [- 1 ])
189
+ colormap_arr = np .array (colormap )
190
+ rgba = colormap_arr [im ]
191
+ overlay_slices .append (img_array_to_uri (rgba ))
192
+
193
+ return overlay_slices
194
+
132
195
def _subid (self , name , use_dict = False ):
133
196
"""Given a name, get the full id including the context id prefix."""
134
197
if use_dict :
@@ -163,15 +226,8 @@ def _create_dash_components(self):
163
226
]
164
227
info ["lowres_size" ] = thumbnail_size
165
228
166
- # Create traces
167
- # todo: can add "%{z[0]}", but that would be the scaled value ...
168
- image_trace = Image (
169
- source = "" , dx = 1 , dy = 1 , hovertemplate = "(%{x}, %{y})<extra></extra>"
170
- )
171
- scatter_trace = Scatter (x = [], y = []) # placeholder
172
-
173
229
# Create the figure object - can be accessed by user via slicer.graph.figure
174
- self ._fig = fig = Figure (data = [image_trace , scatter_trace ])
230
+ self ._fig = fig = Figure (data = [])
175
231
fig .update_layout (
176
232
template = None ,
177
233
margin = dict (l = 0 , r = 0 , b = 0 , t = 0 , pad = 4 ),
@@ -212,15 +268,19 @@ def _create_dash_components(self):
212
268
self ._position = Store (id = self ._subid ("position" , True ), data = 0 )
213
269
self ._requested_index = Store (id = self ._subid ("req-index" ), data = 0 )
214
270
self ._request_data = Store (id = self ._subid ("req-data" ), data = "" )
215
- self ._lowres_data = Store (id = self ._subid ("lowres-data" ), data = thumbnails )
216
- self ._indicators = Store (id = self ._subid ("indicators" ), data = [])
271
+ self ._lowres_data = Store (id = self ._subid ("lowres" ), data = thumbnails )
272
+ self ._overlay_data = Store (id = self ._subid ("overlay" ), data = [])
273
+ self ._img_traces = Store (id = self ._subid ("img-traces" ), data = [])
274
+ self ._indicator_traces = Store (id = self ._subid ("indicator-traces" ), data = [])
217
275
self ._stores = [
218
276
self ._info ,
219
277
self ._position ,
220
278
self ._requested_index ,
221
279
self ._request_data ,
222
280
self ._lowres_data ,
223
- self ._indicators ,
281
+ self ._overlay_data ,
282
+ self ._img_traces ,
283
+ self ._indicator_traces ,
224
284
]
225
285
226
286
def _create_server_callbacks (self ):
@@ -232,13 +292,16 @@ def _create_server_callbacks(self):
232
292
[Input (self ._requested_index .id , "data" )],
233
293
)
234
294
def upload_requested_slice (slice_index ):
235
- slice = self ._slice (slice_index )
236
- return [ slice_index , img_array_to_uri ( slice )]
295
+ slice = img_array_to_uri ( self ._slice (slice_index ) )
296
+ return { "index" : slice_index , " slice" : slice }
237
297
238
298
def _create_client_callbacks (self ):
239
299
"""Create the callbacks that run client-side."""
240
300
app = self ._app
241
301
302
+ # ----------------------------------------------------------------------
303
+ # Callback to update position (in scene coordinates) from the index.
304
+
242
305
app .clientside_callback (
243
306
"""
244
307
function update_position(index, info) {
@@ -250,15 +313,25 @@ def _create_client_callbacks(self):
250
313
[State (self ._info .id , "data" )],
251
314
)
252
315
316
+ # ----------------------------------------------------------------------
317
+ # Callback to request new slices.
318
+ # Note: this callback cannot be merged with the one below, because
319
+ # it would create a circular dependency.
320
+
253
321
app .clientside_callback (
254
322
"""
255
- function handle_slice_index(index) {
323
+ function update_request(index) {
324
+
325
+ // Clear the cache?
256
326
if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; }
257
327
let slice_cache = window.slicecache_for_{{ID}};
328
+
329
+ // Request a new slice (or not)
330
+ let request_index = index;
258
331
if (slice_cache[index]) {
259
332
return window.dash_clientside.no_update;
260
333
} else {
261
- console.log('requesting slice ' + index)
334
+ console.log('request slice ' + index);
262
335
return index;
263
336
}
264
337
}
@@ -269,60 +342,80 @@ def _create_client_callbacks(self):
269
342
[Input (self .slider .id , "value" )],
270
343
)
271
344
345
+ # ----------------------------------------------------------------------
346
+ # Callback that creates a list of image traces (slice and overlay).
347
+
272
348
app .clientside_callback (
273
349
"""
274
- function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, info) {
275
- let new_index = index_and_data[0];
276
- let new_data = index_and_data[1];
277
- // Store data in cache
350
+ function update_image_traces(index, req_data, overlays, lowres, info, current_traces) {
351
+
352
+ // Add data to the cache if the data is indeed new
278
353
if (!window.slicecache_for_{{ID}}) { window.slicecache_for_{{ID}} = {}; }
279
354
let slice_cache = window.slicecache_for_{{ID}};
280
- slice_cache[new_index] = new_data;
281
- // Get the data we need *now*
282
- let data = slice_cache[index];
283
- let x0 = info.origin[0], y0 = info.origin[1];
284
- let dx = info.spacing[0], dy = info.spacing[1];
285
- //slice_cache[new_index] = undefined; // todo: disabled cache for now!
286
- // Maybe we do not need an update
287
- if (!data) {
288
- data = lowres[index];
355
+ for (let trigger of dash_clientside.callback_context.triggered) {
356
+ if (trigger.prop_id.indexOf('req-data') >= 0) {
357
+ slice_cache[req_data.index] = req_data;
358
+ break;
359
+ }
360
+ }
361
+
362
+ // Prepare traces
363
+ let slice_trace = {
364
+ type: 'image',
365
+ x0: info.origin[0],
366
+ y0: info.origin[1],
367
+ dx: info.spacing[0],
368
+ dy: info.spacing[1],
369
+ hovertemplate: '(%{x:.2f}, %{y:.2f})<extra></extra>'
370
+ };
371
+ let overlay_trace = {...slice_trace};
372
+ overlay_trace.hoverinfo = 'skip';
373
+ overlay_trace.source = overlays[index] || '';
374
+ overlay_trace.hovertemplate = '';
375
+ let new_traces = [slice_trace, overlay_trace];
376
+
377
+ // Depending on the state of the cache, use full data, or use lowres and request slice
378
+ if (slice_cache[index]) {
379
+ let cached = slice_cache[index];
380
+ slice_trace.source = cached.slice;
381
+ } else {
382
+ slice_trace.source = lowres[index];
289
383
// Scale the image to take the exact same space as the full-res
290
384
// version. It's not correct, but it looks better ...
291
- dx *= info.size[0] / info.lowres_size[0];
292
- dy *= info.size[1] / info.lowres_size[1];
293
- x0 += 0.5 * dx - 0.5 * info.spacing[0];
294
- y0 += 0.5 * dy - 0.5 * info.spacing[1];
385
+ slice_trace. dx *= info.size[0] / info.lowres_size[0];
386
+ slice_trace. dy *= info.size[1] / info.lowres_size[1];
387
+ slice_trace. x0 += 0.5 * slice_trace. dx - 0.5 * info.spacing[0];
388
+ slice_trace. y0 += 0.5 * slice_trace. dy - 0.5 * info.spacing[1];
295
389
}
296
- if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) {
297
- return window.dash_clientside.no_update;
390
+
391
+ // Has the image data even changed?
392
+ if (!current_traces.length) { current_traces = [{source:''}, {source:''}]; }
393
+ if (new_traces[0].source == current_traces[0].source &&
394
+ new_traces[1].source == current_traces[1].source)
395
+ {
396
+ new_traces = window.dash_clientside.no_update;
298
397
}
299
- // Otherwise, perform update
300
- console.log("updating figure");
301
- let figure = {...ori_figure};
302
- figure.data[0].source = data;
303
- figure.data[0].x0 = x0;
304
- figure.data[0].y0 = y0;
305
- figure.data[0].dx = dx;
306
- figure.data[0].dy = dy;
307
- figure.data[1] = indicators;
308
- return figure;
398
+ return new_traces;
309
399
}
310
400
""" .replace (
311
401
"{{ID}}" , self ._context_id
312
402
),
313
- Output (self .graph .id , "figure " ),
403
+ Output (self ._img_traces .id , "data " ),
314
404
[
315
405
Input (self .slider .id , "value" ),
316
406
Input (self ._request_data .id , "data" ),
317
- Input (self ._indicators .id , "data" ),
407
+ Input (self ._overlay_data .id , "data" ),
318
408
],
319
409
[
320
- State (self .graph .id , "figure" ),
321
410
State (self ._lowres_data .id , "data" ),
322
411
State (self ._info .id , "data" ),
412
+ State (self ._img_traces .id , "data" ),
323
413
],
324
414
)
325
415
416
+ # ----------------------------------------------------------------------
417
+ # Callback to create scatter traces from the positions of other slicers.
418
+
326
419
# Select the *other* axii
327
420
axii = [0 , 1 , 2 ]
328
421
axii .pop (self ._axis )
@@ -349,18 +442,18 @@ def _create_client_callbacks(self):
349
442
x.push(...[pos, pos, pos, pos, pos, pos]);
350
443
y.push(...[y0 - d, y0, null, y1, y1 + d, null]);
351
444
}
352
- return {
445
+ return [ {
353
446
type: 'scatter',
354
447
mode: 'lines',
355
448
line: {color: '#ff00aa'},
356
449
x: x,
357
450
y: y,
358
451
hoverinfo: 'skip',
359
452
version: version
360
- };
453
+ }] ;
361
454
}
362
455
""" ,
363
- Output (self ._indicators .id , "data" ),
456
+ Output (self ._indicator_traces .id , "data" ),
364
457
[
365
458
Input (
366
459
{
@@ -375,6 +468,35 @@ def _create_client_callbacks(self):
375
468
],
376
469
[
377
470
State (self ._info .id , "data" ),
378
- State (self ._indicators .id , "data" ),
471
+ State (self ._indicator_traces .id , "data" ),
472
+ ],
473
+ )
474
+
475
+ # ----------------------------------------------------------------------
476
+ # Callback that composes a figure from multiple trace sources.
477
+
478
+ app .clientside_callback (
479
+ """
480
+ function update_figure(img_traces, indicators, ori_figure) {
481
+
482
+ // Collect traces
483
+ let traces = [];
484
+ for (let trace of img_traces) { traces.push(trace); }
485
+ for (let trace of indicators) { traces.push(trace); }
486
+
487
+ // Update figure
488
+ console.log("updating figure");
489
+ let figure = {...ori_figure};
490
+ figure.data = traces;
491
+ return figure;
492
+ }
493
+ """ ,
494
+ Output (self .graph .id , "figure" ),
495
+ [
496
+ Input (self ._img_traces .id , "data" ),
497
+ Input (self ._indicator_traces .id , "data" ),
498
+ ],
499
+ [
500
+ State (self .graph .id , "figure" ),
379
501
],
380
502
)
0 commit comments