Skip to content

Commit b652039

Browse files
committed
animation + facet kinda working now, but it broke labels
1 parent a431fad commit b652039

File tree

1 file changed

+73
-29
lines changed

1 file changed

+73
-29
lines changed

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pandas as pd
88
from .png import Writer, from_array
99
import numpy as np
10+
import itertools
1011

1112
try:
1213
import xarray
@@ -293,31 +294,41 @@ def imshow(
293294
args = locals()
294295
apply_default_cascade(args)
295296
labels = labels.copy()
296-
nslices = 1
297+
nslices_facet = 1
297298
if facet_col is not None:
298299
if isinstance(facet_col, str):
299300
facet_col = img.dims.index(facet_col)
300-
nslices = img.shape[facet_col]
301-
ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices
302-
nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
301+
nslices_facet = img.shape[facet_col]
302+
facet_slices = range(nslices_facet)
303+
ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices_facet
304+
nrows = (
305+
nslices_facet // ncols + 1
306+
if nslices_facet % ncols
307+
else nslices_facet // ncols
308+
)
303309
else:
304310
nrows = 1
305311
ncols = 1
306312
if animation_frame is not None:
307313
if isinstance(animation_frame, str):
308314
animation_frame = img.dims.index(animation_frame)
309-
nslices = img.shape[animation_frame]
315+
nslices_animation = img.shape[animation_frame]
316+
animation_slices = range(nslices_animation)
310317
slice_through = (facet_col is not None) or (animation_frame is not None)
311-
slice_label = None
312-
slices = range(nslices)
318+
double_slice_through = (facet_col is not None) and (animation_frame is not None)
319+
facet_label = None
320+
animation_label = None
313321
# ----- Define x and y, set labels if img is an xarray -------------------
314322
if xarray_imported and isinstance(img, xarray.DataArray):
315323
dims = list(img.dims)
316-
if slice_through:
317-
slice_index = facet_col if facet_col is not None else animation_frame
318-
slices = img.coords[img.dims[slice_index]].values
319-
_ = dims.pop(slice_index)
320-
slice_label = img.dims[slice_index]
324+
if facet_col is not None:
325+
facet_slices = img.coords[img.dims[facet_col]].values
326+
_ = dims.pop(facet_col)
327+
facet_label = img.dims[facet_col]
328+
if animation_frame is not None:
329+
animation_slices = img.coords[img.dims[animation_frame]].values
330+
_ = dims.pop(animation_frame)
331+
animation_label = img.dims[animation_frame]
321332
y_label, x_label = dims[0], dims[1]
322333
# np.datetime64 is not handled correctly by go.Heatmap
323334
for ax in [x_label, y_label]:
@@ -333,8 +344,10 @@ def imshow(
333344
labels["x"] = x_label
334345
if labels.get("y", None) is None:
335346
labels["y"] = y_label
336-
if labels.get("slice", None) is None:
337-
labels["slice"] = slice_label
347+
if labels.get("animation_slice", None) is None:
348+
labels["animation_slice"] = animation_label
349+
if labels.get("facet_slice", None) is None:
350+
labels["facet_slice"] = facet_label
338351
if labels.get("color", None) is None:
339352
labels["color"] = xarray.plot.utils.label_from_attrs(img)
340353
labels["color"] = labels["color"].replace("\n", "<br>")
@@ -371,11 +384,15 @@ def imshow(
371384
img = np.asanyarray(img)
372385
if facet_col is not None:
373386
img = np.moveaxis(img, facet_col, 0)
387+
print(img.shape)
388+
if animation_frame is not None and animation_frame < facet_col:
389+
animation_frame += 1
374390
facet_col = True
375391
if animation_frame is not None:
376392
img = np.moveaxis(img, animation_frame, 0)
393+
print(img.shape)
377394
animation_frame = True
378-
args["animation_frame"] = (
395+
args["animation_frame"] = ( # TODO
379396
"slice" if labels.get("slice") is None else labels["slice"]
380397
)
381398

@@ -431,9 +448,16 @@ def imshow(
431448
+ "dimension of the img matrix."
432449
)
433450
if slice_through:
451+
iterables = ()
452+
if animation_frame is not None:
453+
iterables += (range(nslices_animation),)
454+
if facet_col is not None:
455+
iterables += (range(nslices_facet),)
434456
traces = [
435-
go.Heatmap(x=x, y=y, z=img_slice, coloraxis="coloraxis1", name=str(i))
436-
for i, img_slice in enumerate(img)
457+
go.Heatmap(
458+
x=x, y=y, z=img[index_tup], coloraxis="coloraxis1", name=str(i)
459+
)
460+
for i, index_tup in enumerate(itertools.product(*iterables))
437461
]
438462
else:
439463
traces = [go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")]
@@ -464,11 +488,21 @@ def imshow(
464488
_vectorize_zvalue(zmin, mode="min"),
465489
_vectorize_zvalue(zmax, mode="max"),
466490
)
491+
if slice_through:
492+
iterables = ()
493+
if animation_frame is not None:
494+
iterables += (range(nslices_animation),)
495+
if facet_col is not None:
496+
iterables += (range(nslices_facet),)
467497
if binary_string:
468498
if zmin is None and zmax is None: # no rescaling, faster
469499
img_rescaled = img
470500
rescale_image = False
471-
elif img.ndim == 2 or (img.ndim == 3 and slice_through):
501+
elif (
502+
img.ndim == 2
503+
or (img.ndim == 3 and slice_through)
504+
or (img.ndim == 4 and double_slice_through)
505+
):
472506
img_rescaled = rescale_intensity(
473507
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
474508
)
@@ -485,14 +519,15 @@ def imshow(
485519
axis=-1,
486520
)
487521
if slice_through:
522+
tuples = [index_tup for index_tup in itertools.product(*iterables)]
488523
img_str = [
489524
_array_to_b64str(
490-
img_rescaled_slice,
525+
img_rescaled[index_tup],
491526
backend=binary_backend,
492527
compression=binary_compression_level,
493528
ext=binary_format,
494529
)
495-
for img_rescaled_slice in img_rescaled
530+
for index_tup in itertools.product(*iterables)
496531
]
497532

498533
else:
@@ -512,8 +547,10 @@ def imshow(
512547
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
513548
if slice_through:
514549
traces = [
515-
go.Image(z=img_slice, zmin=zmin, zmax=zmax, colormodel=colormodel)
516-
for img_slice in img
550+
go.Image(
551+
z=img[index_tup], zmin=zmin, zmax=zmax, colormodel=colormodel
552+
)
553+
for index_tup in itertools.product(*iterables)
517554
]
518555
else:
519556
traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)]
@@ -533,9 +570,9 @@ def imshow(
533570
col_labels = []
534571
if facet_col is not None:
535572
slice_label = "slice" if labels.get("slice") is None else labels["slice"]
536-
if slices is None:
537-
slices = range(nslices)
538-
col_labels = ["%s = %d" % (slice_label, i) for i in slices]
573+
if facet_slices is None:
574+
facet_slices = range(nslices_facet)
575+
col_labels = ["%s = %d" % (slice_label, i) for i in facet_slices]
539576
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
540577
layout_patch = dict()
541578
for attr_name in ["height", "width"]:
@@ -547,11 +584,18 @@ def imshow(
547584
layout_patch["margin"] = {"t": 60}
548585

549586
frame_list = []
550-
for index, (slice_index, trace) in enumerate(zip(slices, traces)):
551-
if facet_col or index == 0:
587+
for index, trace in enumerate(traces):
588+
if (facet_col and index < nrows * ncols) or index == 0:
552589
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
553-
if animation_frame:
554-
frame_list.append(dict(data=trace, layout=layout, name=str(slice_index)))
590+
if animation_frame is not None:
591+
for i in range(nslices_animation):
592+
frame_list.append(
593+
dict(
594+
data=traces[nslices_facet * i : nslices_facet * (i + 1)],
595+
layout=layout,
596+
name=str(i),
597+
)
598+
)
555599
if animation_frame:
556600
fig.frames = frame_list
557601
fig.update_layout(layout)

0 commit comments

Comments
 (0)