@@ -364,7 +364,6 @@ def imshow(
364
364
args ["animation_frame" ] = "plane"
365
365
slice_through = True
366
366
367
- print ("slice_through" , slice_through )
368
367
# Default behaviour of binary_string: True for RGB images, False for 2D
369
368
if binary_string is None :
370
369
if slice_through :
@@ -382,7 +381,11 @@ def imshow(
382
381
383
382
# -------- Contrast rescaling: either minmax or infer ------------------
384
383
if contrast_rescaling is None :
385
- contrast_rescaling = "minmax" if img .ndim == 2 else "infer"
384
+ contrast_rescaling = (
385
+ "minmax"
386
+ if (img .ndim == 2 or (img .ndim == 3 and slice_through ))
387
+ else "infer"
388
+ )
386
389
387
390
# We try to set zmin and zmax only if necessary, because traces have good defaults
388
391
if contrast_rescaling == "minmax" :
@@ -436,10 +439,8 @@ def imshow(
436
439
437
440
# For 2D+RGB data, use Image trace
438
441
elif (
439
- img .ndim == 3
440
- and (img .shape [- 1 ] in [3 , 4 ] or (slice_through and binary_string ))
441
- or (img .ndim == 2 and binary_string )
442
- ):
442
+ img .ndim >= 3 and (img .shape [- 1 ] in [3 , 4 ] or slice_through and binary_string )
443
+ ) or (img .ndim == 2 and binary_string ):
443
444
rescale_image = True # to check whether image has been modified
444
445
if zmin is not None and zmax is not None :
445
446
zmin , zmax = (
@@ -455,15 +456,16 @@ def imshow(
455
456
img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
456
457
)
457
458
else :
458
- img_rescaled = np .dstack (
459
+ img_rescaled = np .stack (
459
460
[
460
461
rescale_intensity (
461
462
img [..., ch ],
462
463
in_range = (zmin [ch ], zmax [ch ]),
463
464
out_range = np .uint8 ,
464
465
)
465
466
for ch in range (img .shape [- 1 ])
466
- ]
467
+ ],
468
+ axis = - 1 ,
467
469
)
468
470
if slice_through :
469
471
img_str = [
@@ -485,10 +487,19 @@ def imshow(
485
487
ext = binary_format ,
486
488
)
487
489
]
488
- traces = [go .Image (source = img_str_slice , name = str (i )) for i , img_str_slice in enumerate (img_str )]
490
+ traces = [
491
+ go .Image (source = img_str_slice , name = str (i ))
492
+ for i , img_str_slice in enumerate (img_str )
493
+ ]
489
494
else :
490
495
colormodel = "rgb" if img .shape [- 1 ] == 3 else "rgba256"
491
- traces = [go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
496
+ if slice_through :
497
+ traces = [
498
+ go .Image (z = img_slice , zmin = zmin , zmax = zmax , colormodel = colormodel )
499
+ for img_slice in img
500
+ ]
501
+ else :
502
+ traces = [go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
492
503
layout = {}
493
504
if origin == "lower" :
494
505
layout ["yaxis" ] = dict (autorange = True )
@@ -546,5 +557,5 @@ def imshow(
546
557
if labels ["y" ]:
547
558
fig .update_yaxes (title_text = labels ["y" ])
548
559
configure_animation_controls (args , go .Image , fig )
549
- #fig.update_layout(template=args["template"], overwrite=True)
560
+ # fig.update_layout(template=args["template"], overwrite=True)
550
561
return fig
0 commit comments