Skip to content

Commit bd25f05

Browse files
committed
more tests and documented parameters in docstring
1 parent d3d88b8 commit bd25f05

File tree

2 files changed

+148
-65
lines changed

2 files changed

+148
-65
lines changed

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 85 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,56 @@
11
import plotly.graph_objs as go
22
from _plotly_utils.basevalidators import ColorscaleValidator
33
from ._core import apply_default_cascade
4-
from PIL import Image
54
from io import BytesIO
65
import base64
76
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
87
import pandas as pd
98
import png
109
import numpy as np
10+
1111
try:
1212
import xarray
1313

1414
xarray_imported = True
1515
except ImportError:
1616
xarray_imported = False
17+
try:
18+
from PIL import Image
19+
20+
pil_imported = True
21+
except ImportError:
22+
pil_imported = False
1723

1824
_float_types = []
1925

2026

21-
def _array_to_b64str(img, backend='pil', compression=4):
27+
def _array_to_b64str(img, backend="pil", compression=4):
28+
"""Converts a numpy array of uint8 into a base64 png string.
29+
30+
Parameters
31+
----------
32+
img: ndarray of uint8
33+
array image
34+
backend: str
35+
'auto', 'pil' or 'pypng'. If 'auto', Pillow is used if installed,
36+
otherwise pypng.
37+
compression: int, between 0 and 9
38+
compression level to be passed to the backend
39+
"""
40+
# PIL and pypng error messages are quite obscure so we catch invalid compression values
41+
if compression < 0 or compression > 9:
42+
raise ValueError("compression level must be between 0 and 9.")
2243
if img.ndim == 2:
23-
mode = 'L'
44+
mode = "L"
2445
elif img.ndim == 3 and img.shape[-1] == 3:
25-
mode = 'RGB'
46+
mode = "RGB"
2647
elif img.ndim == 3 and img.shape[-1] == 4:
27-
mode = 'RGBA'
48+
mode = "RGBA"
2849
else:
29-
raise ValueError(
30-
"Invalid image shape"
31-
)
32-
if backend=='png':
50+
raise ValueError("Invalid image shape")
51+
if backend == "auto":
52+
backend = "pil" if pil_imported else "pypng"
53+
if backend == "pypng":
3354
ndim = img.ndim
3455
sh = img.shape
3556
if ndim == 3:
@@ -40,11 +61,16 @@ def _array_to_b64str(img, backend='pil', compression=4):
4061
with BytesIO() as stream:
4162
w.write(stream, img_png.rows)
4263
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
43-
else:
64+
else: # pil
65+
if not pil_imported:
66+
raise ImportError(
67+
"pillow needs to be installed to use `backend='pil'. Please"
68+
"install pillow or use `backend='pypng'."
69+
)
4470
pil_img = Image.fromarray(img)
4571
prefix = "data:image/png;base64,"
4672
with BytesIO() as stream:
47-
pil_img.save(stream, format='png', compress_level=compression)
73+
pil_img.save(stream, format="png", compress_level=compression)
4874
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
4975
return base64_string
5076

@@ -100,8 +126,10 @@ def imshow(
100126
width=None,
101127
height=None,
102128
aspect=None,
103-
use_binary_string=None,
104129
contrast_rescaling=None,
130+
binary_string=None,
131+
binary_backend="auto",
132+
binary_compression_level=4,
105133
):
106134
"""
107135
Display an image, i.e. data on a 2D regular raster.
@@ -177,7 +205,13 @@ def imshow(
177205
- if None, 'equal' is used for numpy arrays and 'auto' for xarrays
178206
(which have typically heterogeneous coordinates)
179207
180-
use_binary_string: bool, default None
208+
contrast_rescaling: 'minmax', 'infer', or None
209+
how to determine data values corresponding to the bounds of the color
210+
range, when zmin or zmax are not passed. If `minmax`, the min and max
211+
values of the image are used. If `infer`, a heuristic based on the image
212+
data type is used.
213+
214+
binary_string: bool, default None
181215
if True, the image data are first rescaled and encoded as uint8 and
182216
then passed to plotly.js as a b64 PNG string. If False, data are passed
183217
unchanged as a numerical array. Setting to True may lead to performance
@@ -187,11 +221,18 @@ def imshow(
187221
represented as grayscale and with no colorbar if use_binary_string is
188222
True.
189223
190-
contrast_rescaling: 'minmax', 'infer', or None
191-
how to determine data values corresponding to the bounds of the color
192-
range, when zmin or zmax are not passed. If `minmax`, the min and max
193-
values of the image are used. If `infer`, a heuristic based on the image
194-
data type is used.
224+
binary_backend: str, 'auto' (default), 'pil' or 'pypng'
225+
Third-party package for the transformation of numpy arrays to
226+
png b64 strings. If 'auto', Pillow is used if installed, otherwise
227+
pypng.
228+
229+
binary_compression_level: int, between 0 and 9 (default 4)
230+
png compression level to be passed to the backend when transforming an
231+
array to a png b64 string. Increasing `binary_compression` decreases the
232+
size of the png string, but the compression step takes more time. For most
233+
images it is not worth using levels greater than 5, but it's possible to
234+
test `len(fig.data[0].source)` and to time the execution of `imshow` to
235+
tune the level of compression. 0 means no compression (not recommended).
195236
196237
Returns
197238
-------
@@ -217,7 +258,7 @@ def imshow(
217258
labels = labels.copy()
218259
# ----- Define x and y, set labels if img is an xarray -------------------
219260
if xarray_imported and isinstance(img, xarray.DataArray):
220-
if use_binary_string:
261+
if binary_string:
221262
raise ValueError(
222263
"It is not possible to use binary image strings for xarrays."
223264
"Please pass your data as a numpy array instead using"
@@ -262,24 +303,24 @@ def imshow(
262303
if aspect is None:
263304
aspect = "equal"
264305

265-
# Set the value of use_binary_string
306+
# Set the value of binary_string
266307
if isinstance(img, pd.DataFrame):
267-
if use_binary_string:
308+
if binary_string:
268309
raise ValueError("Binary strings cannot be used with pandas arrays")
269310
has_nans = True
270311
else:
271312
has_nans = np.any(np.isnan(img))
272-
if has_nans and use_binary_string:
313+
if has_nans and binary_string:
273314
raise ValueError(
274315
"Binary strings cannot be used with arrays containing NaNs"
275316
)
276317

277318
# --------------- Starting from here img is always a numpy array --------
278319
img = np.asanyarray(img)
279320

280-
# Default behaviour of use_binary_string: True for RGB images, False for 2D
281-
if use_binary_string is None:
282-
use_binary_string = img.ndim >= 3 and not has_nans
321+
# Default behaviour of binary_string: True for RGB images, False for 2D
322+
if binary_string is None:
323+
binary_string = img.ndim >= 3 and not has_nans
283324

284325
# Cast bools to uint8 (also one byte)
285326
if img.dtype == np.bool:
@@ -292,18 +333,18 @@ def imshow(
292333
contrast_rescaling = "minmax" if img.ndim == 2 else "infer"
293334

294335
if contrast_rescaling == "minmax":
295-
if (zmin is not None or use_binary_string) and zmax is None:
336+
if (zmin is not None or binary_string) and zmax is None:
296337
zmax = img.max()
297-
if (zmax is not None or use_binary_string) and zmin is None:
338+
if (zmax is not None or binary_string) and zmin is None:
298339
zmin = img.min()
299340
else:
300-
if zmax is None and img.dtype is not np.uint8:
341+
if zmax is None and (img.dtype is not np.uint8 or img.ndim == 2):
301342
zmax = _infer_zmax_from_type(img)
302343
if zmin is None:
303344
zmin = 0
304345

305-
# For 2d data, use Heatmap trace, unless use_binary_string is True
306-
if img.ndim == 2 and not use_binary_string:
346+
# For 2d data, use Heatmap trace, unless binary_string is True
347+
if img.ndim == 2 and not binary_string:
307348
if y is not None and img.shape[0] != len(y):
308349
raise ValueError(
309350
"The length of the y vector must match the length of the first "
@@ -333,13 +374,9 @@ def imshow(
333374
layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])
334375

335376
# For 2D+RGB data, use Image trace
336-
elif (
337-
img.ndim == 3
338-
and img.shape[-1] in [3, 4]
339-
or (img.ndim == 2 and use_binary_string)
340-
):
377+
elif img.ndim == 3 and img.shape[-1] in [3, 4] or (img.ndim == 2 and binary_string):
341378
zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax)
342-
if use_binary_string:
379+
if binary_string:
343380
if img.ndim == 2:
344381
img_rescaled = rescale_intensity(
345382
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
@@ -357,7 +394,11 @@ def imshow(
357394
)
358395
if origin == "lower":
359396
img_rescaled = img_rescaled[::-1]
360-
img_str = _array_to_b64str(img_rescaled)
397+
img_str = _array_to_b64str(
398+
img_rescaled,
399+
backend=binary_backend,
400+
compression=binary_compression_level,
401+
)
361402
trace = go.Image(source=img_str)
362403
else:
363404
trace = go.Image(z=img, zmin=zmin, zmax=zmax)
@@ -380,8 +421,13 @@ def imshow(
380421
layout_patch["margin"] = {"t": 60}
381422
fig = go.Figure(data=trace, layout=layout)
382423
fig.update_layout(layout_patch)
383-
# does not work yet, Antoine working on it
384-
hover_name = "z" if not use_binary_string else "colorLabel"
424+
# Hover name, z or color
425+
if trace["type"] == "heatmap":
426+
hover_name = "z"
427+
elif img.ndim == 2:
428+
hover_name = "color[0]"
429+
else:
430+
hover_name = "color"
385431
fig.update_traces(
386432
hovertemplate="%s: %%{x}<br>%s: %%{y}<br>%s: %%{%s}<extra></extra>"
387433
% (

0 commit comments

Comments
 (0)