Skip to content

Commit c8e852e

Browse files
committed
animations now work + tests
1 parent d236bc2 commit c8e852e

File tree

2 files changed

+48
-19
lines changed

2 files changed

+48
-19
lines changed

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,6 @@ def imshow(
364364
args["animation_frame"] = "plane"
365365
slice_through = True
366366

367-
print("slice_through", slice_through)
368367
# Default behaviour of binary_string: True for RGB images, False for 2D
369368
if binary_string is None:
370369
if slice_through:
@@ -382,7 +381,11 @@ def imshow(
382381

383382
# -------- Contrast rescaling: either minmax or infer ------------------
384383
if contrast_rescaling is None:
385-
contrast_rescaling = "minmax" if img.ndim == 2 else "infer"
384+
contrast_rescaling = (
385+
"minmax"
386+
if (img.ndim == 2 or (img.ndim == 3 and slice_through))
387+
else "infer"
388+
)
386389

387390
# We try to set zmin and zmax only if necessary, because traces have good defaults
388391
if contrast_rescaling == "minmax":
@@ -436,10 +439,8 @@ def imshow(
436439

437440
# For 2D+RGB data, use Image trace
438441
elif (
439-
img.ndim == 3
440-
and (img.shape[-1] in [3, 4] or (slice_through and binary_string))
441-
or (img.ndim == 2 and binary_string)
442-
):
442+
img.ndim >= 3 and (img.shape[-1] in [3, 4] or slice_through and binary_string)
443+
) or (img.ndim == 2 and binary_string):
443444
rescale_image = True # to check whether image has been modified
444445
if zmin is not None and zmax is not None:
445446
zmin, zmax = (
@@ -455,15 +456,16 @@ def imshow(
455456
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
456457
)
457458
else:
458-
img_rescaled = np.dstack(
459+
img_rescaled = np.stack(
459460
[
460461
rescale_intensity(
461462
img[..., ch],
462463
in_range=(zmin[ch], zmax[ch]),
463464
out_range=np.uint8,
464465
)
465466
for ch in range(img.shape[-1])
466-
]
467+
],
468+
axis=-1,
467469
)
468470
if slice_through:
469471
img_str = [
@@ -485,10 +487,19 @@ def imshow(
485487
ext=binary_format,
486488
)
487489
]
488-
traces = [go.Image(source=img_str_slice, name=str(i)) for i, img_str_slice in enumerate(img_str)]
490+
traces = [
491+
go.Image(source=img_str_slice, name=str(i))
492+
for i, img_str_slice in enumerate(img_str)
493+
]
489494
else:
490495
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
491-
traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)]
496+
if slice_through:
497+
traces = [
498+
go.Image(z=img_slice, zmin=zmin, zmax=zmax, colormodel=colormodel)
499+
for img_slice in img
500+
]
501+
else:
502+
traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)]
492503
layout = {}
493504
if origin == "lower":
494505
layout["yaxis"] = dict(autorange=True)
@@ -546,5 +557,5 @@ def imshow(
546557
if labels["y"]:
547558
fig.update_yaxes(title_text=labels["y"])
548559
configure_animation_controls(args, go.Image, fig)
549-
#fig.update_layout(template=args["template"], overwrite=True)
560+
# fig.update_layout(template=args["template"], overwrite=True)
550561
return fig

packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -327,11 +327,29 @@ def test_facet_col(facet_col, binary_string):
327327
facet_col_wrap=facet_col_wrap,
328328
binary_string=binary_string,
329329
)
330-
if facet_col is not None:
331-
nslices = img.shape[facet_col]
332-
ncols = int(facet_col_wrap)
333-
nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
334-
nmax = ncols * nrows
335-
assert "yaxis%d" % nmax in fig.layout
336-
assert "yaxis%d" % (nmax + 1) not in fig.layout
337-
assert len(fig.data) == nslices
330+
nslices = img.shape[facet_col]
331+
ncols = int(facet_col_wrap)
332+
nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
333+
nmax = ncols * nrows
334+
assert "yaxis%d" % nmax in fig.layout
335+
assert "yaxis%d" % (nmax + 1) not in fig.layout
336+
assert len(fig.data) == nslices
337+
338+
339+
@pytest.mark.parametrize("animation_frame", [0, 1, 2, -1])
340+
@pytest.mark.parametrize("binary_string", [False, True])
341+
def test_animation_frame_grayscale(animation_frame, binary_string):
342+
img = np.random.randint(255, size=(10, 9, 8)).astype(np.uint8)
343+
fig = px.imshow(img, animation_frame=animation_frame, binary_string=binary_string,)
344+
nslices = img.shape[animation_frame]
345+
assert len(fig.frames) == nslices
346+
347+
348+
@pytest.mark.parametrize("animation_frame", [0, 1, 2])
349+
@pytest.mark.parametrize("binary_string", [False, True])
350+
def test_animation_frame_rgb(animation_frame, binary_string):
351+
img = np.random.randint(255, size=(10, 9, 8, 3)).astype(np.uint8)
352+
fig = px.imshow(img, animation_frame=animation_frame, binary_string=binary_string,)
353+
print(binary_string)
354+
nslices = img.shape[animation_frame]
355+
assert len(fig.frames) == nslices

0 commit comments

Comments
 (0)