1
1
import plotly .graph_objs as go
2
2
from _plotly_utils .basevalidators import ColorscaleValidator
3
- from ._core import apply_default_cascade , init_figure
3
+ from ._core import apply_default_cascade , init_figure , configure_animation_controls
4
4
from io import BytesIO
5
5
import base64
6
6
from .imshow_utils import rescale_intensity , _integer_ranges , _integer_types
@@ -133,7 +133,7 @@ def imshow(
133
133
labels = {},
134
134
x = None ,
135
135
y = None ,
136
- animation_frame = False ,
136
+ animation_frame = None ,
137
137
facet_col = None ,
138
138
facet_col_wrap = None ,
139
139
color_continuous_scale = None ,
@@ -353,13 +353,21 @@ def imshow(
353
353
354
354
# --------------- Starting from here img is always a numpy array --------
355
355
img = np .asanyarray (img )
356
+ slice_through = False
356
357
if facet_col is not None :
357
358
img = np .moveaxis (img , facet_col , 0 )
358
359
facet_col = True
359
-
360
+ slice_through = True
361
+ if animation_frame is not None :
362
+ img = np .moveaxis (img , animation_frame , 0 )
363
+ animation_frame = True
364
+ args ["animation_frame" ] = "plane"
365
+ slice_through = True
366
+
367
+ print ("slice_through" , slice_through )
360
368
# Default behaviour of binary_string: True for RGB images, False for 2D
361
369
if binary_string is None :
362
- if facet_col :
370
+ if slice_through :
363
371
binary_string = img .ndim >= 4 and not is_dataframe
364
372
else :
365
373
binary_string = img .ndim >= 3 and not is_dataframe
@@ -391,7 +399,7 @@ def imshow(
391
399
zmin = 0
392
400
393
401
# For 2d data, use Heatmap trace, unless binary_string is True
394
- if (img .ndim == 2 or (img .ndim == 3 and facet_col )) and not binary_string :
402
+ if (img .ndim == 2 or (img .ndim == 3 and slice_through )) and not binary_string :
395
403
if y is not None and img .shape [0 ] != len (y ):
396
404
raise ValueError (
397
405
"The length of the y vector must match the length of the first "
@@ -402,10 +410,10 @@ def imshow(
402
410
"The length of the x vector must match the length of the second "
403
411
+ "dimension of the img matrix."
404
412
)
405
- if facet_col :
413
+ if slice_through :
406
414
traces = [
407
- go .Heatmap (x = x , y = y , z = img_slice , coloraxis = "coloraxis1" )
408
- for img_slice in img
415
+ go .Heatmap (x = x , y = y , z = img_slice , coloraxis = "coloraxis1" , name = str ( i ) )
416
+ for i , img_slice in enumerate ( img )
409
417
]
410
418
else :
411
419
traces = [go .Heatmap (x = x , y = y , z = img , coloraxis = "coloraxis1" )]
@@ -429,7 +437,7 @@ def imshow(
429
437
# For 2D+RGB data, use Image trace
430
438
elif (
431
439
img .ndim == 3
432
- and (img .shape [- 1 ] in [3 , 4 ] or (facet_col and binary_string ))
440
+ and (img .shape [- 1 ] in [3 , 4 ] or (slice_through and binary_string ))
433
441
or (img .ndim == 2 and binary_string )
434
442
):
435
443
rescale_image = True # to check whether image has been modified
@@ -442,7 +450,7 @@ def imshow(
442
450
if zmin is None and zmax is None : # no rescaling, faster
443
451
img_rescaled = img
444
452
rescale_image = False
445
- elif img .ndim == 2 or (img .ndim == 3 and facet_col ):
453
+ elif img .ndim == 2 or (img .ndim == 3 and slice_through ):
446
454
img_rescaled = rescale_intensity (
447
455
img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
448
456
)
@@ -457,7 +465,7 @@ def imshow(
457
465
for ch in range (img .shape [- 1 ])
458
466
]
459
467
)
460
- if facet_col :
468
+ if slice_through :
461
469
img_str = [
462
470
_array_to_b64str (
463
471
img_rescaled_slice ,
@@ -477,7 +485,7 @@ def imshow(
477
485
ext = binary_format ,
478
486
)
479
487
]
480
- traces = [go .Image (source = img_str_slice ) for img_str_slice in img_str ]
488
+ traces = [go .Image (source = img_str_slice , name = str ( i )) for i , img_str_slice in enumerate ( img_str ) ]
481
489
else :
482
490
colormodel = "rgb" if img .shape [- 1 ] == 3 else "rgba256"
483
491
traces = [go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
@@ -498,8 +506,15 @@ def imshow(
498
506
layout_patch ["title_text" ] = args ["title" ]
499
507
elif args ["template" ].layout .margin .t is None :
500
508
layout_patch ["margin" ] = {"t" : 60 }
509
+
510
+ frame_list = []
501
511
for index , trace in enumerate (traces ):
502
- fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
512
+ if facet_col or index == 0 :
513
+ fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
514
+ if animation_frame :
515
+ frame_list .append (dict (data = trace , layout = layout , name = str (index )))
516
+ if animation_frame :
517
+ fig .frames = frame_list
503
518
fig .update_layout (layout )
504
519
fig .update_layout (layout_patch )
505
520
# Hover name, z or color
@@ -530,5 +545,6 @@ def imshow(
530
545
fig .update_xaxes (title_text = labels ["x" ])
531
546
if labels ["y" ]:
532
547
fig .update_yaxes (title_text = labels ["y" ])
533
- fig .update_layout (template = args ["template" ], overwrite = True )
548
+ configure_animation_controls (args , go .Image , fig )
549
+ #fig.update_layout(template=args["template"], overwrite=True)
534
550
return fig
0 commit comments