@@ -290,14 +290,20 @@ def imshow(
290
290
labels = labels .copy ()
291
291
col_labels = []
292
292
if facet_col is not None :
293
+ if isinstance (facet_col , str ):
294
+ facet_col = img .dims .index (facet_col )
293
295
nslices = img .shape [facet_col ]
294
296
ncols = int (facet_col_wrap ) if facet_col_wrap is not None else nslices
295
297
nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
296
298
col_labels = ["plane = %d" % i for i in range (nslices )]
297
299
else :
298
300
nrows = 1
299
301
ncols = 1
302
+ if animation_frame is not None :
303
+ if isinstance (animation_frame , str ):
304
+ animation_frame = img .dims .index (animation_frame )
300
305
slice_through = (facet_col is not None ) or (animation_frame is not None )
306
+ plane_label = None
301
307
fig = init_figure (args , "xy" , [], nrows , ncols , col_labels , [])
302
308
# ----- Define x and y, set labels if img is an xarray -------------------
303
309
if xarray_imported and isinstance (img , xarray .DataArray ):
@@ -307,7 +313,14 @@ def imshow(
307
313
# "Please pass your data as a numpy array instead using"
308
314
# "`img.values`"
309
315
# )
310
- y_label , x_label = img .dims [0 ], img .dims [1 ]
316
+ dims = list (img .dims )
317
+ print (dims )
318
+ if slice_through :
319
+ slice_index = facet_col if facet_col is not None else animation_frame
320
+ _ = dims .pop (slice_index )
321
+ plane_label = img .dims [slice_index ]
322
+ y_label , x_label = dims [0 ], dims [1 ]
323
+ print (y_label , x_label )
311
324
# np.datetime64 is not handled correctly by go.Heatmap
312
325
for ax in [x_label , y_label ]:
313
326
if np .issubdtype (img .coords [ax ].dtype , np .datetime64 ):
@@ -322,6 +335,8 @@ def imshow(
322
335
labels ["x" ] = x_label
323
336
if labels .get ("y" , None ) is None :
324
337
labels ["y" ] = y_label
338
+ if labels .get ("plane" , None ) is None :
339
+ labels ["plane" ] = plane_label
325
340
if labels .get ("color" , None ) is None :
326
341
labels ["color" ] = xarray .plot .utils .label_from_attrs (img )
327
342
labels ["color" ] = labels ["color" ].replace ("\n " , "<br>" )
@@ -362,7 +377,9 @@ def imshow(
362
377
if animation_frame is not None :
363
378
img = np .moveaxis (img , animation_frame , 0 )
364
379
animation_frame = True
365
- args ["animation_frame" ] = "plane"
380
+ args ["animation_frame" ] = (
381
+ "plane" if labels .get ("plane" ) is None else labels ["plane" ]
382
+ )
366
383
367
384
# Default behaviour of binary_string: True for RGB images, False for 2D
368
385
if binary_string is None :
@@ -403,12 +420,14 @@ def imshow(
403
420
404
421
# For 2d data, use Heatmap trace, unless binary_string is True
405
422
if (img .ndim == 2 or (img .ndim == 3 and slice_through )) and not binary_string :
406
- if y is not None and img .shape [0 ] != len (y ):
423
+ y_index = 1 if slice_through else 0
424
+ if y is not None and img .shape [y_index ] != len (y ):
407
425
raise ValueError (
408
426
"The length of the y vector must match the length of the first "
409
427
+ "dimension of the img matrix."
410
428
)
411
- if x is not None and img .shape [1 ] != len (x ):
429
+ x_index = 2 if slice_through else 1
430
+ if x is not None and img .shape [x_index ] != len (x ):
412
431
raise ValueError (
413
432
"The length of the x vector must match the length of the second "
414
433
+ "dimension of the img matrix."
0 commit comments