7
7
import pandas as pd
8
8
from .png import Writer , from_array
9
9
import numpy as np
10
+ import itertools
10
11
11
12
try :
12
13
import xarray
@@ -293,31 +294,41 @@ def imshow(
293
294
args = locals ()
294
295
apply_default_cascade (args )
295
296
labels = labels .copy ()
296
- nslices = 1
297
+ nslices_facet = 1
297
298
if facet_col is not None :
298
299
if isinstance (facet_col , str ):
299
300
facet_col = img .dims .index (facet_col )
300
- nslices = img .shape [facet_col ]
301
- ncols = int (facet_col_wrap ) if facet_col_wrap is not None else nslices
302
- nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
301
+ nslices_facet = img .shape [facet_col ]
302
+ facet_slices = range (nslices_facet )
303
+ ncols = int (facet_col_wrap ) if facet_col_wrap is not None else nslices_facet
304
+ nrows = (
305
+ nslices_facet // ncols + 1
306
+ if nslices_facet % ncols
307
+ else nslices_facet // ncols
308
+ )
303
309
else :
304
310
nrows = 1
305
311
ncols = 1
306
312
if animation_frame is not None :
307
313
if isinstance (animation_frame , str ):
308
314
animation_frame = img .dims .index (animation_frame )
309
- nslices = img .shape [animation_frame ]
315
+ nslices_animation = img .shape [animation_frame ]
316
+ animation_slices = range (nslices_animation )
310
317
slice_through = (facet_col is not None ) or (animation_frame is not None )
311
- slice_label = None
312
- slices = range (nslices )
318
+ double_slice_through = (facet_col is not None ) and (animation_frame is not None )
319
+ facet_label = None
320
+ animation_label = None
313
321
# ----- Define x and y, set labels if img is an xarray -------------------
314
322
if xarray_imported and isinstance (img , xarray .DataArray ):
315
323
dims = list (img .dims )
316
- if slice_through :
317
- slice_index = facet_col if facet_col is not None else animation_frame
318
- slices = img .coords [img .dims [slice_index ]].values
319
- _ = dims .pop (slice_index )
320
- slice_label = img .dims [slice_index ]
324
+ if facet_col is not None :
325
+ facet_slices = img .coords [img .dims [facet_col ]].values
326
+ _ = dims .pop (facet_col )
327
+ facet_label = img .dims [facet_col ]
328
+ if animation_frame is not None :
329
+ animation_slices = img .coords [img .dims [animation_frame ]].values
330
+ _ = dims .pop (animation_frame )
331
+ animation_label = img .dims [animation_frame ]
321
332
y_label , x_label = dims [0 ], dims [1 ]
322
333
# np.datetime64 is not handled correctly by go.Heatmap
323
334
for ax in [x_label , y_label ]:
@@ -333,8 +344,10 @@ def imshow(
333
344
labels ["x" ] = x_label
334
345
if labels .get ("y" , None ) is None :
335
346
labels ["y" ] = y_label
336
- if labels .get ("slice" , None ) is None :
337
- labels ["slice" ] = slice_label
347
+ if labels .get ("animation_slice" , None ) is None :
348
+ labels ["animation_slice" ] = animation_label
349
+ if labels .get ("facet_slice" , None ) is None :
350
+ labels ["facet_slice" ] = facet_label
338
351
if labels .get ("color" , None ) is None :
339
352
labels ["color" ] = xarray .plot .utils .label_from_attrs (img )
340
353
labels ["color" ] = labels ["color" ].replace ("\n " , "<br>" )
@@ -371,11 +384,15 @@ def imshow(
371
384
img = np .asanyarray (img )
372
385
if facet_col is not None :
373
386
img = np .moveaxis (img , facet_col , 0 )
387
+ print (img .shape )
388
+ if animation_frame is not None and animation_frame < facet_col :
389
+ animation_frame += 1
374
390
facet_col = True
375
391
if animation_frame is not None :
376
392
img = np .moveaxis (img , animation_frame , 0 )
393
+ print (img .shape )
377
394
animation_frame = True
378
- args ["animation_frame" ] = (
395
+ args ["animation_frame" ] = ( # TODO
379
396
"slice" if labels .get ("slice" ) is None else labels ["slice" ]
380
397
)
381
398
@@ -431,9 +448,16 @@ def imshow(
431
448
+ "dimension of the img matrix."
432
449
)
433
450
if slice_through :
451
+ iterables = ()
452
+ if animation_frame is not None :
453
+ iterables += (range (nslices_animation ),)
454
+ if facet_col is not None :
455
+ iterables += (range (nslices_facet ),)
434
456
traces = [
435
- go .Heatmap (x = x , y = y , z = img_slice , coloraxis = "coloraxis1" , name = str (i ))
436
- for i , img_slice in enumerate (img )
457
+ go .Heatmap (
458
+ x = x , y = y , z = img [index_tup ], coloraxis = "coloraxis1" , name = str (i )
459
+ )
460
+ for i , index_tup in enumerate (itertools .product (* iterables ))
437
461
]
438
462
else :
439
463
traces = [go .Heatmap (x = x , y = y , z = img , coloraxis = "coloraxis1" )]
@@ -464,11 +488,21 @@ def imshow(
464
488
_vectorize_zvalue (zmin , mode = "min" ),
465
489
_vectorize_zvalue (zmax , mode = "max" ),
466
490
)
491
+ if slice_through :
492
+ iterables = ()
493
+ if animation_frame is not None :
494
+ iterables += (range (nslices_animation ),)
495
+ if facet_col is not None :
496
+ iterables += (range (nslices_facet ),)
467
497
if binary_string :
468
498
if zmin is None and zmax is None : # no rescaling, faster
469
499
img_rescaled = img
470
500
rescale_image = False
471
- elif img .ndim == 2 or (img .ndim == 3 and slice_through ):
501
+ elif (
502
+ img .ndim == 2
503
+ or (img .ndim == 3 and slice_through )
504
+ or (img .ndim == 4 and double_slice_through )
505
+ ):
472
506
img_rescaled = rescale_intensity (
473
507
img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
474
508
)
@@ -485,14 +519,15 @@ def imshow(
485
519
axis = - 1 ,
486
520
)
487
521
if slice_through :
522
+ tuples = [index_tup for index_tup in itertools .product (* iterables )]
488
523
img_str = [
489
524
_array_to_b64str (
490
- img_rescaled_slice ,
525
+ img_rescaled [ index_tup ] ,
491
526
backend = binary_backend ,
492
527
compression = binary_compression_level ,
493
528
ext = binary_format ,
494
529
)
495
- for img_rescaled_slice in img_rescaled
530
+ for index_tup in itertools . product ( * iterables )
496
531
]
497
532
498
533
else :
@@ -512,8 +547,10 @@ def imshow(
512
547
colormodel = "rgb" if img .shape [- 1 ] == 3 else "rgba256"
513
548
if slice_through :
514
549
traces = [
515
- go .Image (z = img_slice , zmin = zmin , zmax = zmax , colormodel = colormodel )
516
- for img_slice in img
550
+ go .Image (
551
+ z = img [index_tup ], zmin = zmin , zmax = zmax , colormodel = colormodel
552
+ )
553
+ for index_tup in itertools .product (* iterables )
517
554
]
518
555
else :
519
556
traces = [go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
@@ -533,9 +570,9 @@ def imshow(
533
570
col_labels = []
534
571
if facet_col is not None :
535
572
slice_label = "slice" if labels .get ("slice" ) is None else labels ["slice" ]
536
- if slices is None :
537
- slices = range (nslices )
538
- col_labels = ["%s = %d" % (slice_label , i ) for i in slices ]
573
+ if facet_slices is None :
574
+ facet_slices = range (nslices_facet )
575
+ col_labels = ["%s = %d" % (slice_label , i ) for i in facet_slices ]
539
576
fig = init_figure (args , "xy" , [], nrows , ncols , col_labels , [])
540
577
layout_patch = dict ()
541
578
for attr_name in ["height" , "width" ]:
@@ -547,11 +584,18 @@ def imshow(
547
584
layout_patch ["margin" ] = {"t" : 60 }
548
585
549
586
frame_list = []
550
- for index , ( slice_index , trace ) in enumerate (zip ( slices , traces ) ):
551
- if facet_col or index == 0 :
587
+ for index , trace in enumerate (traces ):
588
+ if ( facet_col and index < nrows * ncols ) or index == 0 :
552
589
fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
553
- if animation_frame :
554
- frame_list .append (dict (data = trace , layout = layout , name = str (slice_index )))
590
+ if animation_frame is not None :
591
+ for i in range (nslices_animation ):
592
+ frame_list .append (
593
+ dict (
594
+ data = traces [nslices_facet * i : nslices_facet * (i + 1 )],
595
+ layout = layout ,
596
+ name = str (i ),
597
+ )
598
+ )
555
599
if animation_frame :
556
600
fig .frames = frame_list
557
601
fig .update_layout (layout )
0 commit comments