Skip to content

Commit 34cc0de

Browse files
committed
more tests and docstring
1 parent e551672 commit 34cc0de

File tree

2 files changed

+117
-32
lines changed

2 files changed

+117
-32
lines changed

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,20 @@
3333
_integer_ranges = {t: (np.iinfo(t).min, np.iinfo(t).max) for t in _integer_types}
3434

3535

36-
def _array_to_b64str(img, ext='png'):
36+
def _array_to_b64str(img, ext="png"):
3737
pil_img = Image.fromarray(img)
3838
if ext == "jpg":
39-
ext = "jpeg"
39+
ext = "jpeg"
4040
buff = BytesIO()
4141
pil_img.save(buff, format=ext)
42-
if ext == 'png':
43-
prefix = b'data:image/png;base64,'
44-
elif ext == 'jpeg':
45-
prefix = b'data:image/jpeg;base64,'
42+
if ext == "png":
43+
prefix = b"data:image/png;base64,"
44+
elif ext == "jpeg":
45+
prefix = b"data:image/jpeg;base64,"
4646
else:
47-
raise ValueError("accepted image formats are 'png' and 'jpeg' but %s was passed" %format)
47+
raise ValueError(
48+
"accepted image formats are 'png' and 'jpeg' but %s was passed" % format
49+
)
4850
image_string = (prefix + base64.b64encode(buff.getvalue())).decode("utf-8")
4951
return image_string
5052

@@ -177,6 +179,22 @@ def imshow(
177179
- if None, 'equal' is used for numpy arrays and 'auto' for xarrays
178180
(which have typically heterogeneous coordinates)
179181
182+
use_binary_string: bool, default None
183+
if True, the image data are first rescaled and encoded as uint8 and
184+
then passed to plotly.js as a b64 PNG string. If False, data are passed
185+
unchanged as a numerical array. Setting to True may lead to performance
186+
gains, at the cost of a loss of precision depending on the original data
187+
type. If None, use_binary_string is set to True for multichannel (eg) RGB
188+
arrays, and to False for single-channel (2D) arrays. 2D arrays are
189+
represented as grayscale and with no colorbar if use_binary_string is
190+
True.
191+
192+
contrast_rescaling: 'image', 'dtype', or None
193+
how to determine data values corresponding to the bounds of the color
194+
range, when zmin or zmax are not passed. If `image`, the min and max
195+
values of the image are used. If `dtype`, a heuristic based on the image
196+
data type is used.
197+
180198
Returns
181199
-------
182200
fig : graph_objects.Figure containing the displayed image
@@ -203,10 +221,10 @@ def imshow(
203221
if xarray_imported and isinstance(img, xarray.DataArray):
204222
if use_binary_string:
205223
raise ValueError(
206-
"It is not possible to use binary image strings for xarrays."
207-
"Please pass your data as a numpy array instead using"
208-
"`img.values`"
209-
)
224+
"It is not possible to use binary image strings for xarrays."
225+
"Please pass your data as a numpy array instead using"
226+
"`img.values`"
227+
)
210228
y_label, x_label = img.dims[0], img.dims[1]
211229
# np.datetime64 is not handled correctly by go.Heatmap
212230
for ax in [x_label, y_label]:
@@ -254,7 +272,9 @@ def imshow(
254272
else:
255273
has_nans = np.any(np.isnan(img))
256274
if has_nans and use_binary_string:
257-
raise ValueError("Binary strings cannot be used with arrays containing NaNs")
275+
raise ValueError(
276+
"Binary strings cannot be used with arrays containing NaNs"
277+
)
258278

259279
# --------------- Starting from here img is always a numpy array --------
260280
img = np.asanyarray(img)
@@ -268,8 +288,8 @@ def imshow(
268288
img = 255 * img.astype(np.uint8)
269289

270290
if contrast_rescaling is None:
271-
contrast_rescaling='image' if img.ndim == 2 else 'dtype'
272-
if contrast_rescaling == 'image':
291+
contrast_rescaling = "image" if img.ndim == 2 else "dtype"
292+
if contrast_rescaling == "image":
273293
if (zmin is not None or use_binary_string) and zmax is None:
274294
zmax = img.max()
275295
if (zmax is not None or use_binary_string) and zmin is None:
@@ -312,14 +332,30 @@ def imshow(
312332
layout["coloraxis1"]["colorbar"] = dict(title_text=labels["color"])
313333

314334
# For 2D+RGB data, use Image trace
315-
elif img.ndim == 3 and img.shape[-1] in [3, 4] or (img.ndim == 2 and use_binary_string):
335+
elif (
336+
img.ndim == 3
337+
and img.shape[-1] in [3, 4]
338+
or (img.ndim == 2 and use_binary_string)
339+
):
316340
zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax)
317341
if use_binary_string:
318342
if img.ndim == 2:
319-
img_rescaled = rescale_intensity(img, in_range=(zmin[0], zmax[0]), out_range=np.uint8)
343+
img_rescaled = rescale_intensity(
344+
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
345+
)
320346
else:
321-
img_rescaled = np.dstack([rescale_intensity(img[..., ch], in_range=(zmin[ch], zmax[ch]), out_range=np.uint8)
322-
for ch in range(img.shape[-1])])
347+
img_rescaled = np.dstack(
348+
[
349+
rescale_intensity(
350+
img[..., ch],
351+
in_range=(zmin[ch], zmax[ch]),
352+
out_range=np.uint8,
353+
)
354+
for ch in range(img.shape[-1])
355+
]
356+
)
357+
if origin == "lower":
358+
img_rescaled = img_rescaled[::-1]
323359
img_str = _array_to_b64str(img_rescaled)
324360
trace = go.Image(source=img_str)
325361
else:
@@ -343,11 +379,17 @@ def imshow(
343379
layout_patch["margin"] = {"t": 60}
344380
fig = go.Figure(data=trace, layout=layout)
345381
fig.update_layout(layout_patch)
346-
if not use_binary_string:
347-
fig.update_traces(
348-
hovertemplate="%s: %%{x}<br>%s: %%{y}<br>%s: %%{z}<extra></extra>"
349-
% (labels["x"] or "x", labels["y"] or "y", labels["color"] or "color",)
382+
# does not work yet, Antoine working on it
383+
hover_name = "z" if not use_binary_string else "colorLabel"
384+
fig.update_traces(
385+
hovertemplate="%s: %%{x}<br>%s: %%{y}<br>%s: %%{%s}<extra></extra>"
386+
% (
387+
labels["x"] or "x",
388+
labels["y"] or "y",
389+
labels["color"] or "color",
390+
hover_name,
350391
)
392+
)
351393
if labels["x"]:
352394
fig.update_xaxes(title_text=labels["x"])
353395
if labels["y"]:

packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ def decode_image_string(image_string):
1515
"""
1616
Converts image string to numpy array.
1717
"""
18-
if 'png' in image_string[:22]:
18+
if "png" in image_string[:22]:
1919
return np.asarray(Image.open(BytesIO(base64.b64decode(image_string[22:]))))
20-
elif 'jpeg' in image_string[:23]:
20+
elif "jpeg" in image_string[:23]:
2121
return np.asaray(Image.open(BytesIO(base64.b64decode(image_string[23:]))))
2222
else:
2323
raise ValueError("image string format not recognized")
@@ -57,14 +57,20 @@ def test_automatic_zmax_from_dtype():
5757
assert fig.data[0]["zmax"] == (val, val, val, 1)
5858

5959

60-
def test_origin():
61-
for img in [img_rgb, img_gray]:
62-
fig = px.imshow(img, origin="lower")
60+
@pytest.mark.parametrize("use_binary_string", [False, True])
61+
def test_origin(use_binary_string):
62+
for i, img in enumerate([img_rgb, img_gray]):
63+
fig = px.imshow(img, origin="lower", use_binary_string=use_binary_string)
6364
assert fig.layout.yaxis.autorange == True
64-
fig = px.imshow(img_rgb)
65+
if use_binary_string and i == 0:
66+
assert np.all(img[::-1] == decode_image_string(fig.data[0].source))
67+
fig = px.imshow(img_rgb, use_binary_string=use_binary_string)
6568
assert fig.layout.yaxis.autorange is None
66-
fig = px.imshow(img_gray)
67-
assert fig.layout.yaxis.autorange == "reversed"
69+
fig = px.imshow(img_gray, use_binary_string=use_binary_string)
70+
if use_binary_string:
71+
assert fig.layout.yaxis.autorange is None
72+
else:
73+
assert fig.layout.yaxis.autorange == "reversed"
6874

6975

7076
def test_colorscale():
@@ -121,7 +127,7 @@ def test_zmax_floats():
121127
def test_zmin_zmax_range_color():
122128
img = img_gray / 100.0
123129
fig = px.imshow(img)
124-
#assert not (fig.layout.coloraxis.cmin or fig.layout.coloraxis.cmax)
130+
# assert not (fig.layout.coloraxis.cmin or fig.layout.coloraxis.cmax)
125131
fig1 = px.imshow(img, zmin=0.2, zmax=0.8)
126132
fig2 = px.imshow(img, range_color=[0.2, 0.8])
127133
assert fig1 == fig2
@@ -197,4 +203,41 @@ def test_imshow_source():
197203
assert np.all(decoded_img == img_rgb)
198204

199205

200-
# def test_imshow_source_dtype_zmax():
206+
@pytest.mark.parametrize(
207+
"dtype",
208+
[
209+
np.uint8,
210+
np.uint16,
211+
np.int8,
212+
np.int16,
213+
np.int32,
214+
np.int64,
215+
np.float32,
216+
np.float64,
217+
],
218+
)
219+
@pytest.mark.parametrize("contrast_rescaling", ["image", "dtype"])
220+
def test_imshow_source_dtype_zmax(dtype, contrast_rescaling):
221+
img = np.arange(100, dtype=dtype).reshape((10, 10))
222+
fig = px.imshow(img, use_binary_string=True, contrast_rescaling=contrast_rescaling)
223+
if contrast_rescaling == "image":
224+
assert (
225+
np.max(
226+
np.abs(
227+
rescale_intensity(img, in_range="image", out_range=np.uint8)
228+
- decode_image_string(fig.data[0].source)
229+
)
230+
)
231+
< 1
232+
)
233+
else:
234+
if dtype in [np.uint8, np.float32, np.float64]:
235+
assert np.all(img == decode_image_string(fig.data[0].source))
236+
else:
237+
assert (
238+
np.abs(
239+
np.max(decode_image_string(fig.data[0].source))
240+
- 255 * img.max() / np.iinfo(dtype).max
241+
)
242+
< 1
243+
)

0 commit comments

Comments
 (0)