Skip to content

Commit c7285a3

Browse files
committed
simplified code
1 parent 59c6622 commit c7285a3

File tree

2 files changed

+27
-42
lines changed

2 files changed

+27
-42
lines changed

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 24 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,11 @@ def imshow(
434434
zmin = 0
435435

436436
# For 2d data, use Heatmap trace, unless binary_string is True
437-
if (img.ndim == 2 or (img.ndim == 3 and slice_through)) and not binary_string:
437+
if (
438+
img.ndim == 2
439+
or (img.ndim == 3 and slice_through)
440+
or (img.ndim == 4 and double_slice_through)
441+
) and not binary_string:
438442
y_index = 1 if slice_through else 0
439443
if y is not None and img.shape[y_index] != len(y):
440444
raise ValueError(
@@ -447,20 +451,16 @@ def imshow(
447451
"The length of the x vector must match the length of the second "
448452
+ "dimension of the img matrix."
449453
)
454+
iterables = ()
450455
if slice_through:
451-
iterables = ()
452456
if animation_frame is not None:
453457
iterables += (range(nslices_animation),)
454458
if facet_col is not None:
455459
iterables += (range(nslices_facet),)
456-
traces = [
457-
go.Heatmap(
458-
x=x, y=y, z=img[index_tup], coloraxis="coloraxis1", name=str(i)
459-
)
460-
for i, index_tup in enumerate(itertools.product(*iterables))
461-
]
462-
else:
463-
traces = [go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")]
460+
traces = [
461+
go.Heatmap(x=x, y=y, z=img[index_tup], coloraxis="coloraxis1", name=str(i))
462+
for i, index_tup in enumerate(itertools.product(*iterables))
463+
]
464464
autorange = True if origin == "lower" else "reversed"
465465
layout = dict(yaxis=dict(autorange=autorange))
466466
if aspect == "equal":
@@ -488,8 +488,8 @@ def imshow(
488488
_vectorize_zvalue(zmin, mode="min"),
489489
_vectorize_zvalue(zmax, mode="max"),
490490
)
491+
iterables = ()
491492
if slice_through:
492-
iterables = ()
493493
if animation_frame is not None:
494494
iterables += (range(nslices_animation),)
495495
if facet_col is not None:
@@ -518,42 +518,26 @@ def imshow(
518518
],
519519
axis=-1,
520520
)
521-
if slice_through:
522-
tuples = [index_tup for index_tup in itertools.product(*iterables)]
523-
img_str = [
524-
_array_to_b64str(
525-
img_rescaled[index_tup],
526-
backend=binary_backend,
527-
compression=binary_compression_level,
528-
ext=binary_format,
529-
)
530-
for index_tup in itertools.product(*iterables)
531-
]
521+
img_str = [
522+
_array_to_b64str(
523+
img_rescaled[index_tup],
524+
backend=binary_backend,
525+
compression=binary_compression_level,
526+
ext=binary_format,
527+
)
528+
for index_tup in itertools.product(*iterables)
529+
]
532530

533-
else:
534-
img_str = [
535-
_array_to_b64str(
536-
img_rescaled,
537-
backend=binary_backend,
538-
compression=binary_compression_level,
539-
ext=binary_format,
540-
)
541-
]
542531
traces = [
543532
go.Image(source=img_str_slice, name=str(i))
544533
for i, img_str_slice in enumerate(img_str)
545534
]
546535
else:
547536
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
548-
if slice_through:
549-
traces = [
550-
go.Image(
551-
z=img[index_tup], zmin=zmin, zmax=zmax, colormodel=colormodel
552-
)
553-
for index_tup in itertools.product(*iterables)
554-
]
555-
else:
556-
traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)]
537+
traces = [
538+
go.Image(z=img[index_tup], zmin=zmin, zmax=zmax, colormodel=colormodel)
539+
for index_tup in itertools.product(*iterables)
540+
]
557541
layout = {}
558542
if origin == "lower":
559543
layout["yaxis"] = dict(autorange=True)

packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,9 +366,10 @@ def test_animation_frame_rgb(animation_frame, binary_string):
366366
assert len(fig.frames) == nslices
367367

368368

369-
def test_animation_and_facet():
369+
@pytest.mark.parametrize("binary_string", [False, True])
370+
def test_animation_and_facet(binary_string):
370371
img = np.random.randint(255, size=(10, 9, 8, 7)).astype(np.uint8)
371-
fig = px.imshow(img, animation_frame=0, facet_col=1, binary_string=True)
372+
fig = px.imshow(img, animation_frame=0, facet_col=1, binary_string=binary_string)
372373
nslices = img.shape[0]
373374
assert len(fig.frames) == nslices
374375
assert len(fig.data) == img.shape[1]

0 commit comments

Comments
 (0)