Skip to content

Commit 8be8ca0

Browse files
committed
WIP: add facet_col arg to imshow
1 parent afb5c4d commit 8be8ca0

File tree

2 files changed

+80
-20
lines changed

2 files changed

+80
-20
lines changed

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def imshow(
134134
x=None,
135135
y=None,
136136
animation_frame=False,
137-
facet_col=False,
137+
facet_col=None,
138138
facet_col_wrap=None,
139139
color_continuous_scale=None,
140140
color_continuous_midpoint=None,
@@ -189,6 +189,14 @@ def imshow(
189189
their lengths must match the lengths of the second and first dimensions of the
190190
img argument. They are auto-populated if the input is an xarray.
191191
192+
facet_col: int, optional (default None)
193+
axis number along which the image array is slices to create a facetted plot.
194+
195+
facet_col_wrap: int
196+
Maximum number of facet columns. Wraps the column variable at this width,
197+
so that the column facets span multiple rows.
198+
Ignored if `facet_col` is None.
199+
192200
color_continuous_scale : str or list of str
193201
colormap used to map scalar data to colors (for a 2D image). This parameter is
194202
not used for RGB or RGBA images. If a string is provided, it should be the name
@@ -280,14 +288,14 @@ def imshow(
280288
args = locals()
281289
apply_default_cascade(args)
282290
labels = labels.copy()
283-
if facet_col:
284-
nslices = img.shape[-1]
285-
ncols = facet_col_wrap
286-
nrows = nslices / ncols
291+
if facet_col is not None:
292+
nslices = img.shape[facet_col]
293+
ncols = int(facet_col_wrap)
294+
nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
287295
else:
288296
nrows = 1
289297
ncols = 1
290-
fig = init_figure(args, 'xy', [], nrows, ncols, [], [])
298+
fig = init_figure(args, "xy", [], nrows, ncols, [], [])
291299
# ----- Define x and y, set labels if img is an xarray -------------------
292300
if xarray_imported and isinstance(img, xarray.DataArray):
293301
if binary_string:
@@ -345,10 +353,16 @@ def imshow(
345353

346354
# --------------- Starting from here img is always a numpy array --------
347355
img = np.asanyarray(img)
356+
if facet_col is not None:
357+
img = np.moveaxis(img, facet_col, 0)
358+
facet_col = True
348359

349360
# Default behaviour of binary_string: True for RGB images, False for 2D
350361
if binary_string is None:
351-
binary_string = img.ndim >= 3 and not is_dataframe
362+
if facet_col:
363+
binary_string = img.ndim >= 4 and not is_dataframe
364+
else:
365+
binary_string = img.ndim >= 3 and not is_dataframe
352366

353367
# Cast bools to uint8 (also one byte)
354368
if img.dtype == np.bool:
@@ -377,7 +391,7 @@ def imshow(
377391
zmin = 0
378392

379393
# For 2d data, use Heatmap trace, unless binary_string is True
380-
if img.ndim == 2 and not binary_string:
394+
if (img.ndim == 2 or (img.ndim == 3 and facet_col)) and not binary_string:
381395
if y is not None and img.shape[0] != len(y):
382396
raise ValueError(
383397
"The length of the y vector must match the length of the first "
@@ -388,7 +402,13 @@ def imshow(
388402
"The length of the x vector must match the length of the second "
389403
+ "dimension of the img matrix."
390404
)
391-
trace = go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")
405+
if facet_col:
406+
traces = [
407+
go.Heatmap(x=x, y=y, z=img_slice, coloraxis="coloraxis1")
408+
for img_slice in img
409+
]
410+
else:
411+
traces = [go.Heatmap(x=x, y=y, z=img, coloraxis="coloraxis1")]
392412
autorange = True if origin == "lower" else "reversed"
393413
layout = dict(yaxis=dict(autorange=autorange))
394414
if aspect == "equal":
@@ -407,7 +427,11 @@ def imshow(
407427
layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])
408428

409429
# For 2D+RGB data, use Image trace
410-
elif img.ndim == 3 and img.shape[-1] in [3, 4] or (img.ndim == 2 and binary_string):
430+
elif (
431+
img.ndim == 3
432+
and (img.shape[-1] in [3, 4] or (facet_col and binary_string))
433+
or (img.ndim == 2 and binary_string)
434+
):
411435
rescale_image = True # to check whether image has been modified
412436
if zmin is not None and zmax is not None:
413437
zmin, zmax = (
@@ -418,7 +442,7 @@ def imshow(
418442
if zmin is None and zmax is None: # no rescaling, faster
419443
img_rescaled = img
420444
rescale_image = False
421-
elif img.ndim == 2:
445+
elif img.ndim == 2 or (img.ndim == 3 and facet_col):
422446
img_rescaled = rescale_intensity(
423447
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
424448
)
@@ -433,16 +457,30 @@ def imshow(
433457
for ch in range(img.shape[-1])
434458
]
435459
)
436-
img_str = _array_to_b64str(
437-
img_rescaled,
438-
backend=binary_backend,
439-
compression=binary_compression_level,
440-
ext=binary_format,
441-
)
442-
trace = go.Image(source=img_str)
460+
if facet_col:
461+
img_str = [
462+
_array_to_b64str(
463+
img_rescaled_slice,
464+
backend=binary_backend,
465+
compression=binary_compression_level,
466+
ext=binary_format,
467+
)
468+
for img_rescaled_slice in img_rescaled
469+
]
470+
471+
else:
472+
img_str = [
473+
_array_to_b64str(
474+
img_rescaled,
475+
backend=binary_backend,
476+
compression=binary_compression_level,
477+
ext=binary_format,
478+
)
479+
]
480+
traces = [go.Image(source=img_str_slice) for img_str_slice in img_str]
443481
else:
444482
colormodel = "rgb" if img.shape[-1] == 3 else "rgba256"
445-
trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)
483+
traces = [go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)]
446484
layout = {}
447485
if origin == "lower":
448486
layout["yaxis"] = dict(autorange=True)
@@ -460,7 +498,8 @@ def imshow(
460498
layout_patch["title_text"] = args["title"]
461499
elif args["template"].layout.margin.t is None:
462500
layout_patch["margin"] = {"t": 60}
463-
fig.add_trace(trace)
501+
for index, trace in enumerate(traces):
502+
fig.add_trace(trace, row=nrows - index // ncols, col=index % ncols + 1)
464503
fig.update_layout(layout)
465504
fig.update_layout(layout_patch)
466505
# Hover name, z or color

packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,3 +314,24 @@ def test_imshow_hovertemplate(binary_string):
314314
fig.data[0].hovertemplate
315315
== "x: %{x}<br>y: %{y}<br>color: %{z}<extra></extra>"
316316
)
317+
318+
319+
@pytest.mark.parametrize("facet_col", [0, 1, 2, -1])
320+
@pytest.mark.parametrize("binary_string", [False, True])
321+
def test_facet_col(facet_col, binary_string):
322+
img = np.random.randint(255, size=(10, 9, 8))
323+
facet_col_wrap = 3
324+
fig = px.imshow(
325+
img,
326+
facet_col=facet_col,
327+
facet_col_wrap=facet_col_wrap,
328+
binary_string=binary_string,
329+
)
330+
if facet_col is not None:
331+
nslices = img.shape[facet_col]
332+
ncols = int(facet_col_wrap)
333+
nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
334+
nmax = ncols * nrows
335+
assert "yaxis%d" % nmax in fig.layout
336+
assert "yaxis%d" % (nmax + 1) not in fig.layout
337+
assert len(fig.data) == nslices

0 commit comments

Comments
 (0)