Skip to content

Commit d236bc2

Browse files
committed
animations work for grayscale images, with or without binary string
1 parent 8be8ca0 commit d236bc2

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import plotly.graph_objs as go
22
from _plotly_utils.basevalidators import ColorscaleValidator
3-
from ._core import apply_default_cascade, init_figure
3+
from ._core import apply_default_cascade, init_figure, configure_animation_controls
44
from io import BytesIO
55
import base64
66
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
@@ -133,7 +133,7 @@ def imshow(
133133
labels={},
134134
x=None,
135135
y=None,
136-
animation_frame=False,
136+
animation_frame=None,
137137
facet_col=None,
138138
facet_col_wrap=None,
139139
color_continuous_scale=None,
@@ -353,13 +353,21 @@ def imshow(
353353

354354
# --------------- Starting from here img is always a numpy array --------
355355
img = np.asanyarray(img)
356+
slice_through = False
356357
if facet_col is not None:
357358
img = np.moveaxis(img, facet_col, 0)
358359
facet_col = True
359-
360+
slice_through = True
361+
if animation_frame is not None:
362+
img = np.moveaxis(img, animation_frame, 0)
363+
animation_frame = True
364+
args["animation_frame"] = "plane"
365+
slice_through = True
366+
367+
print("slice_through", slice_through)
360368
# Default behaviour of binary_string: True for RGB images, False for 2D
361369
if binary_string is None:
362-
if facet_col:
370+
if slice_through:
363371
binary_string = img.ndim >= 4 and not is_dataframe
364372
else:
365373
binary_string = img.ndim >= 3 and not is_dataframe
@@ -391,7 +399,7 @@ def imshow(
391399
zmin = 0
392400

393401
# For 2d data, use Heatmap trace, unless binary_string is True
394-
if (img.ndim == 2 or (img.ndim == 3 and facet_col)) and not binary_string:
402+
if (img.ndim == 2 or (img.ndim == 3 and slice_through)) and not binary_string:
395403
if y is not None and img.shape[0] != len(y):
396404
raise ValueError(
397405
"The length of the y vector must match the length of the first "
@@ -402,10 +410,10 @@ def imshow(
402410
"The length of the x vector must match the length of the second "
403411
+ "dimension of the img matrix."
404412
)
405-
if facet_col:
413+
if slice_through:
406414
traces = [
407-
go.Heatmap(x=x, y=y, z=img_slice, coloraxis="coloraxis1")
408-
for img_slice in img
415+
go.Heatmap(x=x, y=y, z=img_slice, coloraxis="coloraxis1", name=str(i))
416+
for i, img_slice in enumerate(img)
409417
]
410418
else:
411419
traces = [go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")]
@@ -429,7 +437,7 @@ def imshow(
429437
# For 2D+RGB data, use Image trace
430438
elif (
431439
img.ndim == 3
432-
and (img.shape[-1] in [3, 4] or (facet_col and binary_string))
440+
and (img.shape[-1] in [3, 4] or (slice_through and binary_string))
433441
or (img.ndim == 2 and binary_string)
434442
):
435443
rescale_image = True # to check whether image has been modified
@@ -442,7 +450,7 @@ def imshow(
442450
if zmin is None and zmax is None: # no rescaling, faster
443451
img_rescaled = img
444452
rescale_image = False
445-
elif img.ndim == 2 or (img.ndim == 3 and facet_col):
453+
elif img.ndim == 2 or (img.ndim == 3 and slice_through):
446454
img_rescaled = rescale_intensity(
447455
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
448456
)
@@ -457,7 +465,7 @@ def imshow(
457465
for ch in range(img.shape[-1])
458466
]
459467
)
460-
if facet_col:
468+
if slice_through:
461469
img_str = [
462470
_array_to_b64str(
463471
img_rescaled_slice,
@@ -477,7 +485,7 @@ def imshow(
477485
ext=binary_format,
478486
)
479487
]
480-
traces = [go.Image(source=img_str_slice) for img_str_slice in img_str]
488+
traces = [go.Image(source=img_str_slice, name=str(i)) for i, img_str_slice in enumerate(img_str)]
481489
else:
482490
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
483491
traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)]
@@ -498,8 +506,15 @@ def imshow(
498506
layout_patch["title_text"] = args["title"]
499507
elif args["template"].layout.margin.t is None:
500508
layout_patch["margin"] = {"t": 60}
509+
510+
frame_list = []
501511
for index, trace in enumerate(traces):
502-
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
512+
if facet_col or index == 0:
513+
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
514+
if animation_frame:
515+
frame_list.append(dict(data=trace, layout=layout, name=str(index)))
516+
if animation_frame:
517+
fig.frames = frame_list
503518
fig.update_layout(layout)
504519
fig.update_layout(layout_patch)
505520
# Hover name, z or color
@@ -530,5 +545,6 @@ def imshow(
530545
fig.update_xaxes(title_text=labels["x"])
531546
if labels["y"]:
532547
fig.update_yaxes(title_text=labels["y"])
533-
fig.update_layout(template=args["template"], overwrite=True)
548+
configure_animation_controls(args, go.Image, fig)
549+
#fig.update_layout(template=args["template"], overwrite=True)
534550
return fig

0 commit comments

Comments
 (0)