Skip to content

Commit 882810f

Browse files
committed
animation work for xarrays, still need to fix slider label
1 parent fbb3f65 commit 882810f

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

doc/python/imshow.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,19 @@ fig = px.imshow(img, animation_frame=0, binary_string=True)
451451
fig.show()
452452
```
453453

454+
### Animations of xarray datasets
455+
456+
*Introduced in plotly 4.11*
457+
458+
```python
459+
import plotly.express as px
460+
import xarray as xr
461+
# Load xarray from dataset included in the xarray tutorial
462+
ds = xr.tutorial.open_dataset('air_temperature').air[:20]
463+
fig = px.imshow(ds, animation_frame='lat', color_continuous_scale='RdBu_r')
464+
fig.show()
465+
```
466+
454467
#### Reference
455468

456469
See https://plotly.com/python/reference/image/ for more information and chart attribute options!

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,14 +290,20 @@ def imshow(
290290
labels = labels.copy()
291291
col_labels = []
292292
if facet_col is not None:
293+
if isinstance(facet_col, str):
294+
facet_col = img.dims.index(facet_col)
293295
nslices = img.shape[facet_col]
294296
ncols = int(facet_col_wrap) if facet_col_wrap is not None else nslices
295297
nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
296298
col_labels = ["plane = %d" % i for i in range(nslices)]
297299
else:
298300
nrows = 1
299301
ncols = 1
302+
if animation_frame is not None:
303+
if isinstance(animation_frame, str):
304+
animation_frame = img.dims.index(animation_frame)
300305
slice_through = (facet_col is not None) or (animation_frame is not None)
306+
plane_label = None
301307
fig = init_figure(args, "xy", [], nrows, ncols, col_labels, [])
302308
# ----- Define x and y, set labels if img is an xarray -------------------
303309
if xarray_imported and isinstance(img, xarray.DataArray):
@@ -307,7 +313,14 @@ def imshow(
307313
# "Please pass your data as a numpy array instead using"
308314
# "`img.values`"
309315
# )
310-
y_label, x_label = img.dims[0], img.dims[1]
316+
dims = list(img.dims)
317+
print(dims)
318+
if slice_through:
319+
slice_index = facet_col if facet_col is not None else animation_frame
320+
_ = dims.pop(slice_index)
321+
plane_label = img.dims[slice_index]
322+
y_label, x_label = dims[0], dims[1]
323+
print(y_label, x_label)
311324
# np.datetime64 is not handled correctly by go.Heatmap
312325
for ax in [x_label, y_label]:
313326
if np.issubdtype(img.coords[ax].dtype, np.datetime64):
@@ -322,6 +335,8 @@ def imshow(
322335
labels["x"] = x_label
323336
if labels.get("y", None) is None:
324337
labels["y"] = y_label
338+
if labels.get("plane", None) is None:
339+
labels["plane"] = plane_label
325340
if labels.get("color", None) is None:
326341
labels["color"] = xarray.plot.utils.label_from_attrs(img)
327342
labels["color"] = labels["color"].replace("\n", "<br>")
@@ -362,7 +377,9 @@ def imshow(
362377
if animation_frame is not None:
363378
img = np.moveaxis(img, animation_frame, 0)
364379
animation_frame = True
365-
args["animation_frame"] = "plane"
380+
args["animation_frame"] = (
381+
"plane" if labels.get("plane") is None else labels["plane"]
382+
)
366383

367384
# Default behaviour of binary_string: True for RGB images, False for 2D
368385
if binary_string is None:
@@ -403,12 +420,14 @@ def imshow(
403420

404421
# For 2d data, use Heatmap trace, unless binary_string is True
405422
if (img.ndim == 2 or (img.ndim == 3 and slice_through)) and not binary_string:
406-
if y is not None and img.shape[0] != len(y):
423+
y_index = 1 if slice_through else 0
424+
if y is not None and img.shape[y_index] != len(y):
407425
raise ValueError(
408426
"The length of the y vector must match the length of the first "
409427
+ "dimension of the img matrix."
410428
)
411-
if x is not None and img.shape[1] != len(x):
429+
x_index = 2 if slice_through else 1
430+
if x is not None and img.shape[x_index] != len(x):
412431
raise ValueError(
413432
"The length of the x vector must match the length of the second "
414433
+ "dimension of the img matrix."

0 commit comments

Comments
 (0)