Skip to content

Commit db3b221

Browse files
authored
A few fixes and performance tweaks (#37)
* fix detecting layout changes, and use different timeout for rate-limiting depending on source * Fix that initialization of figures was sometimes weird * produce more reasonable sized thumbnails for elongated data * rename lowres -> thumbnail
1 parent 3696fc0 commit db3b221

File tree

3 files changed

+70
-51
lines changed

3 files changed

+70
-51
lines changed

dash_slicer/slicer.py

Lines changed: 49 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class VolumeSlicer:
3535
color (str): the color for this slicer. By default the color is
3636
red, green, or blue, depending on the axis. Set to empty string
3737
for "no color".
38-
thumbnail (int or bool): linear size of low-resolution data to be
38+
thumbnail (int or bool): preferred size of low-resolution data to be
3939
uploaded to the client. If ``False``, the full-resolution data are
4040
uploaded client-side. If ``True`` (default), a default value of 32 is
4141
used.
@@ -108,12 +108,18 @@ def __init__(
108108
# Check and store thumbnail
109109
if not (isinstance(thumbnail, (int, bool))):
110110
raise ValueError("thumbnail must be a boolean or an integer.")
111-
# No thumbnail if thumbnail size is larger than image size
112-
if isinstance(thumbnail, int) and thumbnail > np.max(volume.shape):
113-
thumbnail = False
114-
if thumbnail is True:
115-
thumbnail = 32 # default size
116-
self._thumbnail = thumbnail
111+
if thumbnail is False:
112+
self._thumbnail = False
113+
elif thumbnail is None or thumbnail is True:
114+
self._thumbnail = 32 # default size
115+
else:
116+
thumbnail = int(thumbnail)
117+
if thumbnail >= np.max(volume.shape[:3]):
118+
self._thumbnail = False # dont go larger than image size
119+
elif thumbnail <= 0:
120+
self._thumbnail = False # consider 0 and -1 the same as False
121+
else:
122+
self._thumbnail = thumbnail
117123

118124
# Check and store scene id, and generate
119125
if scene_id is None:
@@ -299,15 +305,16 @@ def _create_dash_components(self):
299305
"""Create the graph, slider, figure, etc."""
300306
info = self._slice_info
301307

302-
# Prep low-res slices
303-
if self._thumbnail is False:
308+
# Prep low-res slices. The get_thumbnail_size() is a bit like
309+
# a simulation to get the low-res size.
310+
if not self._thumbnail:
304311
thumbnail_size = None
305-
info["lowres_size"] = info["size"]
312+
info["thumbnail_size"] = info["size"]
306313
else:
307-
thumbnail_size = get_thumbnail_size(
308-
info["size"][:2], (self._thumbnail, self._thumbnail)
314+
thumbnail_size = self._thumbnail
315+
info["thumbnail_size"] = get_thumbnail_size(
316+
info["size"][:2], thumbnail_size
309317
)
310-
info["lowres_size"] = thumbnail_size
311318
thumbnails = [
312319
img_array_to_uri(self._slice(i), thumbnail_size)
313320
for i in range(info["size"][2])
@@ -361,8 +368,8 @@ def _create_dash_components(self):
361368
# A dict of static info for this slicer
362369
self._info = Store(id=self._subid("info"), data=info)
363370

364-
# A list of low-res slices (encoded as base64-png)
365-
self._lowres_data = Store(id=self._subid("lowres"), data=thumbnails)
371+
# A list of low-res slices, or the full-res data (encoded as base64-png)
372+
self._thumbs_data = Store(id=self._subid("thumbs"), data=thumbnails)
366373

367374
# A list of mask slices (encoded as base64-png or null)
368375
self._overlay_data = Store(id=self._subid("overlay"), data=[])
@@ -389,7 +396,7 @@ def _create_dash_components(self):
389396

390397
self._stores = [
391398
self._info,
392-
self._lowres_data,
399+
self._thumbs_data,
393400
self._overlay_data,
394401
self._server_data,
395402
self._img_traces,
@@ -490,16 +497,20 @@ def _create_client_callbacks(self):
490497

491498
app.clientside_callback(
492499
"""
493-
function update_index_rate_limiting(index, relayoutData, n_intervals, interval, info, figure) {
500+
function update_index_rate_limiting(index, relayoutData, n_intervals, info, figure) {
494501
495502
if (!window._slicer_{{ID}}) window._slicer_{{ID}} = {};
496503
let private_state = window._slicer_{{ID}};
497504
let now = window.performance.now();
498505
499506
// Get whether the slider was moved
500-
let slider_was_moved = false;
507+
let slider_value_changed = false;
508+
let graph_layout_changed = false;
509+
let timer_ticked = false;
501510
for (let trigger of dash_clientside.callback_context.triggered) {
502-
if (trigger.prop_id.indexOf('slider') >= 0) slider_was_moved = true;
511+
if (trigger.prop_id.indexOf('slider') >= 0) slider_value_changed = true;
512+
if (trigger.prop_id.indexOf('graph') >= 0) graph_layout_changed = true;
513+
if (trigger.prop_id.indexOf('timer') >= 0) timer_ticked = true;
503514
}
504515
505516
// Calculate view range based on the volume
@@ -513,17 +524,8 @@ def _create_client_callbacks(self):
513524
];
514525
515526
// Get view range from the figure. We make range[0] < range[1]
516-
let range_was_changed = false;
517527
let xrangeFig = figure.layout.xaxis.range
518528
let yrangeFig = figure.layout.yaxis.range;
519-
if (relayoutData && relayoutData.xaxis && relayoutData.xaxis.range) {
520-
xrangeFig = relayoutData.xaxis.range;
521-
range_was_changed = true;
522-
}
523-
if (relayoutData && relayoutData.yaxis && relayoutData.yaxis.range) {
524-
yrangeFig = relayoutData.yaxis.range;
525-
range_was_changed = true
526-
}
527529
xrangeFig = [Math.min(xrangeFig[0], xrangeFig[1]), Math.max(xrangeFig[0], xrangeFig[1])];
528530
yrangeFig = [Math.min(yrangeFig[0], yrangeFig[1]), Math.max(yrangeFig[0], yrangeFig[1])];
529531
@@ -549,18 +551,25 @@ def _create_client_callbacks(self):
549551
// If the slider moved, remember the time when this happened
550552
private_state.new_time = private_state.new_time || 0;
551553
552-
if (slider_was_moved || range_was_changed) {
554+
555+
if (slider_value_changed) {
556+
private_state.new_time = now;
557+
private_state.timeout = 200;
558+
} else if (graph_layout_changed) {
553559
private_state.new_time = now;
560+
private_state.timeout = 400; // need longer timeout for smooth scroll zoom
554561
} else if (!n_intervals) {
555562
private_state.new_time = now;
563+
private_state.timeout = 100;
556564
}
557565
558-
// We can either update the rate-limited index interval ms after
559-
// the real index changed, or interval ms after it stopped
566+
// We can either update the rate-limited index timeout ms after
567+
// the real index changed, or timeout ms after it stopped
560568
// changing. The former makes the indicators come along while
561569
// dragging the slider, the latter is better for a smooth
562-
// experience, and the interval can be set much lower.
563-
if (now - private_state.new_time >= interval) {
570+
// experience, and the timeout can be set much lower.
571+
if (private_state.timeout && timer_ticked && now - private_state.new_time >= private_state.timeout) {
572+
private_state.timeout = 0;
564573
disable_timer = true;
565574
new_state = {
566575
index: index,
@@ -574,7 +583,6 @@ def _create_client_callbacks(self):
574583
if (index != private_state.index) {
575584
private_state.index = index;
576585
new_state.index_changed = true;
577-
console.log('requesting slice ' + index);
578586
}
579587
}
580588
@@ -593,7 +601,6 @@ def _create_client_callbacks(self):
593601
Input(self._timer.id, "n_intervals"),
594602
],
595603
[
596-
State(self._timer.id, "interval"),
597604
State(self._info.id, "data"),
598605
State(self._graph.id, "figure"),
599606
],
@@ -604,7 +611,7 @@ def _create_client_callbacks(self):
604611

605612
app.clientside_callback(
606613
"""
607-
function update_image_traces(index, server_data, overlays, lowres, info, current_traces) {
614+
function update_image_traces(index, server_data, overlays, thumbnails, info, current_traces) {
608615
609616
// Prepare traces
610617
let slice_trace = {
@@ -621,16 +628,16 @@ def _create_client_callbacks(self):
621628
overlay_trace.hovertemplate = '';
622629
let new_traces = [slice_trace, overlay_trace];
623630
624-
// Use full data, or use lowres
631+
// Use full data, or use thumbnails
625632
if (index == server_data.index) {
626633
slice_trace.source = server_data.slice;
627634
} else {
628-
slice_trace.source = lowres[index];
635+
slice_trace.source = thumbnails[index];
629636
// Scale the image to take the exact same space as the full-res
630637
// version. Note that depending on how the low-res data is
631638
// created, the pixel centers may not be correctly aligned.
632-
slice_trace.dx *= info.size[0] / info.lowres_size[0];
633-
slice_trace.dy *= info.size[1] / info.lowres_size[1];
639+
slice_trace.dx *= info.size[0] / info.thumbnail_size[0];
640+
slice_trace.dy *= info.size[1] / info.thumbnail_size[1];
634641
slice_trace.x0 += 0.5 * slice_trace.dx - 0.5 * info.stepsize[0];
635642
slice_trace.y0 += 0.5 * slice_trace.dy - 0.5 * info.stepsize[1];
636643
}
@@ -654,7 +661,7 @@ def _create_client_callbacks(self):
654661
Input(self._overlay_data.id, "data"),
655662
],
656663
[
657-
State(self._lowres_data.id, "data"),
664+
State(self._thumbs_data.id, "data"),
658665
State(self._info.id, "data"),
659666
State(self._img_traces.id, "data"),
660667
],
@@ -737,6 +744,7 @@ def _create_client_callbacks(self):
737744
State(self._info.id, "data"),
738745
State(self._state.id, "data"),
739746
],
747+
prevent_initial_call=True,
740748
)
741749

742750
# ----------------------------------------------------------------------

dash_slicer/utils.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,39 @@ def img_as_ubyte(img):
1818
return img.astype(np.uint8)
1919

2020

21-
def img_array_to_uri(img_array, new_size=None):
21+
def _thumbnail_size_from_scalar(image_size, ref_size):
22+
if image_size[0] > image_size[1]:
23+
return int(ref_size * image_size[0] / image_size[1]), ref_size
24+
else:
25+
return ref_size, int(ref_size * image_size[1] / image_size[0])
26+
27+
28+
def img_array_to_uri(img_array, ref_size=None):
2229
"""Convert the given image (numpy array) into a base64-encoded PNG."""
2330
img_array = img_as_ubyte(img_array)
2431
# todo: leverage this Plotly util once it becomes part of the public API (also drops the Pillow dependency)
2532
# from plotly.express._imshow import _array_to_b64str
2633
# return _array_to_b64str(img_array)
2734
img_pil = PIL.Image.fromarray(img_array)
28-
if new_size:
29-
img_pil.thumbnail(new_size)
35+
if ref_size:
36+
size = img_array.shape[1], img_array.shape[0]
37+
img_pil.thumbnail(_thumbnail_size_from_scalar(size, ref_size))
3038
# The below was taken from plotly.utils.ImageUriValidator.pil_image_to_uri()
3139
f = io.BytesIO()
3240
img_pil.save(f, format="PNG")
3341
base64_str = base64.b64encode(f.getvalue()).decode()
3442
return "data:image/png;base64," + base64_str
3543

3644

37-
def get_thumbnail_size(size, new_size):
45+
def get_thumbnail_size(size, ref_size):
3846
"""Given an image size (w, h), and a preferred smaller size,
3947
get the actual size if we let Pillow downscale it.
4048
"""
49+
# Note that if you call thumbnail() to get the resulting size, then call
50+
# thumbnail() again with that size, the result may be yet another size.
4151
img_array = np.zeros(list(reversed(size)), np.uint8)
4252
img_pil = PIL.Image.fromarray(img_array)
43-
img_pil.thumbnail(new_size)
53+
img_pil.thumbnail(_thumbnail_size_from_scalar(size, ref_size))
4454
return img_pil.size
4555

4656

tests/test_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def test_img_array_to_uri():
3131
im = np.random.uniform(0, 255, (100, 100)).astype(np.uint8)
3232

3333
r1 = img_array_to_uri(im)
34-
r2 = img_array_to_uri(im, (32, 32))
35-
r3 = img_array_to_uri(im, (8, 8))
34+
r2 = img_array_to_uri(im, 32)
35+
r3 = img_array_to_uri(im, 8)
3636

3737
for r in (r1, r2, r3):
3838
assert isinstance(r, str)
@@ -43,9 +43,10 @@ def test_img_array_to_uri():
4343

4444
def test_get_thumbnail_size():
4545

46-
assert get_thumbnail_size((100, 100), (16, 16)) == (16, 16)
47-
assert get_thumbnail_size((50, 100), (16, 16)) == (8, 16)
48-
assert get_thumbnail_size((100, 100), (8, 16)) == (8, 8)
46+
assert get_thumbnail_size((100, 100), 16) == (16, 16)
47+
assert get_thumbnail_size((50, 100), 16) == (16, 32)
48+
assert get_thumbnail_size((100, 100), 8) == (8, 8)
49+
assert get_thumbnail_size((100, 50), 8) == (16, 8)
4950

5051

5152
def test_shape3d_to_size2d():

0 commit comments

Comments
 (0)