@@ -134,7 +134,7 @@ def imshow(
134
134
x = None ,
135
135
y = None ,
136
136
animation_frame = False ,
137
- facet_col = False ,
137
+ facet_col = None ,
138
138
facet_col_wrap = None ,
139
139
color_continuous_scale = None ,
140
140
color_continuous_midpoint = None ,
@@ -189,6 +189,14 @@ def imshow(
189
189
their lengths must match the lengths of the second and first dimensions of the
190
190
img argument. They are auto-populated if the input is an xarray.
191
191
192
+ facet_col: int, optional (default None)
193
+ axis number along which the image array is slices to create a facetted plot.
194
+
195
+ facet_col_wrap: int
196
+ Maximum number of facet columns. Wraps the column variable at this width,
197
+ so that the column facets span multiple rows.
198
+ Ignored if `facet_col` is None.
199
+
192
200
color_continuous_scale : str or list of str
193
201
colormap used to map scalar data to colors (for a 2D image). This parameter is
194
202
not used for RGB or RGBA images. If a string is provided, it should be the name
@@ -280,14 +288,14 @@ def imshow(
280
288
args = locals ()
281
289
apply_default_cascade (args )
282
290
labels = labels .copy ()
283
- if facet_col :
284
- nslices = img .shape [- 1 ]
285
- ncols = facet_col_wrap
286
- nrows = nslices / ncols
291
+ if facet_col is not None :
292
+ nslices = img .shape [facet_col ]
293
+ ncols = int ( facet_col_wrap )
294
+ nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
287
295
else :
288
296
nrows = 1
289
297
ncols = 1
290
- fig = init_figure (args , 'xy' , [], nrows , ncols , [], [])
298
+ fig = init_figure (args , "xy" , [], nrows , ncols , [], [])
291
299
# ----- Define x and y, set labels if img is an xarray -------------------
292
300
if xarray_imported and isinstance (img , xarray .DataArray ):
293
301
if binary_string :
@@ -345,10 +353,16 @@ def imshow(
345
353
346
354
# --------------- Starting from here img is always a numpy array --------
347
355
img = np .asanyarray (img )
356
+ if facet_col is not None :
357
+ img = np .moveaxis (img , facet_col , 0 )
358
+ facet_col = True
348
359
349
360
# Default behaviour of binary_string: True for RGB images, False for 2D
350
361
if binary_string is None :
351
- binary_string = img .ndim >= 3 and not is_dataframe
362
+ if facet_col :
363
+ binary_string = img .ndim >= 4 and not is_dataframe
364
+ else :
365
+ binary_string = img .ndim >= 3 and not is_dataframe
352
366
353
367
# Cast bools to uint8 (also one byte)
354
368
if img .dtype == np .bool :
@@ -377,7 +391,7 @@ def imshow(
377
391
zmin = 0
378
392
379
393
# For 2d data, use Heatmap trace, unless binary_string is True
380
- if img .ndim == 2 and not binary_string :
394
+ if ( img .ndim == 2 or ( img . ndim == 3 and facet_col )) and not binary_string :
381
395
if y is not None and img .shape [0 ] != len (y ):
382
396
raise ValueError (
383
397
"The length of the y vector must match the length of the first "
@@ -388,7 +402,13 @@ def imshow(
388
402
"The length of the x vector must match the length of the second "
389
403
+ "dimension of the img matrix."
390
404
)
391
- trace = go .Heatmap (x = x , y = y , z = img , coloraxis = "coloraxis1" )
405
+ if facet_col :
406
+ traces = [
407
+ go .Heatmap (x = x , y = y , z = img_slice , coloraxis = "coloraxis1" )
408
+ for img_slice in img
409
+ ]
410
+ else :
411
+ traces = [go .Heatmap (x = x , y = y , z = img , coloraxis = "coloraxis1" )]
392
412
autorange = True if origin == "lower" else "reversed"
393
413
layout = dict (yaxis = dict (autorange = autorange ))
394
414
if aspect == "equal" :
@@ -407,7 +427,11 @@ def imshow(
407
427
layout ["coloraxis1" ]["colorbar" ] = dict (title_text = labels ["color" ])
408
428
409
429
# For 2D+RGB data, use Image trace
410
- elif img .ndim == 3 and img .shape [- 1 ] in [3 , 4 ] or (img .ndim == 2 and binary_string ):
430
+ elif (
431
+ img .ndim == 3
432
+ and (img .shape [- 1 ] in [3 , 4 ] or (facet_col and binary_string ))
433
+ or (img .ndim == 2 and binary_string )
434
+ ):
411
435
rescale_image = True # to check whether image has been modified
412
436
if zmin is not None and zmax is not None :
413
437
zmin , zmax = (
@@ -418,7 +442,7 @@ def imshow(
418
442
if zmin is None and zmax is None : # no rescaling, faster
419
443
img_rescaled = img
420
444
rescale_image = False
421
- elif img .ndim == 2 :
445
+ elif img .ndim == 2 or ( img . ndim == 3 and facet_col ) :
422
446
img_rescaled = rescale_intensity (
423
447
img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
424
448
)
@@ -433,16 +457,30 @@ def imshow(
433
457
for ch in range (img .shape [- 1 ])
434
458
]
435
459
)
436
- img_str = _array_to_b64str (
437
- img_rescaled ,
438
- backend = binary_backend ,
439
- compression = binary_compression_level ,
440
- ext = binary_format ,
441
- )
442
- trace = go .Image (source = img_str )
460
+ if facet_col :
461
+ img_str = [
462
+ _array_to_b64str (
463
+ img_rescaled_slice ,
464
+ backend = binary_backend ,
465
+ compression = binary_compression_level ,
466
+ ext = binary_format ,
467
+ )
468
+ for img_rescaled_slice in img_rescaled
469
+ ]
470
+
471
+ else :
472
+ img_str = [
473
+ _array_to_b64str (
474
+ img_rescaled ,
475
+ backend = binary_backend ,
476
+ compression = binary_compression_level ,
477
+ ext = binary_format ,
478
+ )
479
+ ]
480
+ traces = [go .Image (source = img_str_slice ) for img_str_slice in img_str ]
443
481
else :
444
482
colormodel = "rgb" if img .shape [- 1 ] == 3 else "rgba256"
445
- trace = go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )
483
+ traces = [ go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
446
484
layout = {}
447
485
if origin == "lower" :
448
486
layout ["yaxis" ] = dict (autorange = True )
@@ -460,7 +498,8 @@ def imshow(
460
498
layout_patch ["title_text" ] = args ["title" ]
461
499
elif args ["template" ].layout .margin .t is None :
462
500
layout_patch ["margin" ] = {"t" : 60 }
463
- fig .add_trace (trace )
501
+ for index , trace in enumerate (traces ):
502
+ fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
464
503
fig .update_layout (layout )
465
504
fig .update_layout (layout_patch )
466
505
# Hover name, z or color
0 commit comments