1
1
import plotly .graph_objs as go
2
2
from _plotly_utils .basevalidators import ColorscaleValidator
3
3
from ._core import apply_default_cascade
4
- from PIL import Image
5
4
from io import BytesIO
6
5
import base64
7
6
from .imshow_utils import rescale_intensity , _integer_ranges , _integer_types
8
7
import pandas as pd
9
8
import png
10
9
import numpy as np
10
+
11
11
try :
12
12
import xarray
13
13
14
14
xarray_imported = True
15
15
except ImportError :
16
16
xarray_imported = False
17
+ try :
18
+ from PIL import Image
19
+
20
+ pil_imported = True
21
+ except ImportError :
22
+ pil_imported = False
17
23
18
24
_float_types = []
19
25
20
26
21
- def _array_to_b64str (img , backend = 'pil' , compression = 4 ):
27
+ def _array_to_b64str (img , backend = "pil" , compression = 4 ):
28
+ """Converts a numpy array of uint8 into a base64 png string.
29
+
30
+ Parameters
31
+ ----------
32
+ img: ndarray of uint8
33
+ array image
34
+ backend: str
35
+ 'auto', 'pil' or 'pypng'. If 'auto', Pillow is used if installed,
36
+ otherwise pypng.
37
+ compression: int, between 0 and 9
38
+ compression level to be passed to the backend
39
+ """
40
+ # PIL and pypng error messages are quite obscure so we catch invalid compression values
41
+ if compression < 0 or compression > 9 :
42
+ raise ValueError ("compression level must be between 0 and 9." )
22
43
if img .ndim == 2 :
23
- mode = 'L'
44
+ mode = "L"
24
45
elif img .ndim == 3 and img .shape [- 1 ] == 3 :
25
- mode = ' RGB'
46
+ mode = " RGB"
26
47
elif img .ndim == 3 and img .shape [- 1 ] == 4 :
27
- mode = ' RGBA'
48
+ mode = " RGBA"
28
49
else :
29
- raise ValueError (
30
- "Invalid image shape"
31
- )
32
- if backend == 'png' :
50
+ raise ValueError ("Invalid image shape" )
51
+ if backend == "auto" :
52
+ backend = "pil" if pil_imported else "pypng"
53
+ if backend == "pypng" :
33
54
ndim = img .ndim
34
55
sh = img .shape
35
56
if ndim == 3 :
@@ -40,11 +61,16 @@ def _array_to_b64str(img, backend='pil', compression=4):
40
61
with BytesIO () as stream :
41
62
w .write (stream , img_png .rows )
42
63
base64_string = prefix + base64 .b64encode (stream .getvalue ()).decode ("utf-8" )
43
- else :
64
+ else : # pil
65
+ if not pil_imported :
66
+ raise ImportError (
67
+ "pillow needs to be installed to use `backend='pil'. Please"
68
+ "install pillow or use `backend='pypng'."
69
+ )
44
70
pil_img = Image .fromarray (img )
45
71
prefix = "data:image/png;base64,"
46
72
with BytesIO () as stream :
47
- pil_img .save (stream , format = ' png' , compress_level = compression )
73
+ pil_img .save (stream , format = " png" , compress_level = compression )
48
74
base64_string = prefix + base64 .b64encode (stream .getvalue ()).decode ("utf-8" )
49
75
return base64_string
50
76
@@ -100,8 +126,10 @@ def imshow(
100
126
width = None ,
101
127
height = None ,
102
128
aspect = None ,
103
- use_binary_string = None ,
104
129
contrast_rescaling = None ,
130
+ binary_string = None ,
131
+ binary_backend = "auto" ,
132
+ binary_compression_level = 4 ,
105
133
):
106
134
"""
107
135
Display an image, i.e. data on a 2D regular raster.
@@ -177,7 +205,13 @@ def imshow(
177
205
- if None, 'equal' is used for numpy arrays and 'auto' for xarrays
178
206
(which have typically heterogeneous coordinates)
179
207
180
- use_binary_string: bool, default None
208
+ contrast_rescaling: 'minmax', 'infer', or None
209
+ how to determine data values corresponding to the bounds of the color
210
+ range, when zmin or zmax are not passed. If `minmax`, the min and max
211
+ values of the image are used. If `infer`, a heuristic based on the image
212
+ data type is used.
213
+
214
+ binary_string: bool, default None
181
215
if True, the image data are first rescaled and encoded as uint8 and
182
216
then passed to plotly.js as a b64 PNG string. If False, data are passed
183
217
unchanged as a numerical array. Setting to True may lead to performance
@@ -187,11 +221,18 @@ def imshow(
187
221
represented as grayscale and with no colorbar if use_binary_string is
188
222
True.
189
223
190
- contrast_rescaling: 'minmax', 'infer', or None
191
- how to determine data values corresponding to the bounds of the color
192
- range, when zmin or zmax are not passed. If `minmax`, the min and max
193
- values of the image are used. If `infer`, a heuristic based on the image
194
- data type is used.
224
+ binary_backend: str, 'auto' (default), 'pil' or 'pypng'
225
+ Third-party package for the transformation of numpy arrays to
226
+ png b64 strings. If 'auto', Pillow is used if installed, otherwise
227
+ pypng.
228
+
229
+ binary_compression_level: int, between 0 and 9 (default 4)
230
+ png compression level to be passed to the backend when transforming an
231
+ array to a png b64 string. Increasing `binary_compression` decreases the
232
+ size of the png string, but the compression step takes more time. For most
233
+ images it is not worth using levels greater than 5, but it's possible to
234
+ test `len(fig.data[0].source)` and to time the execution of `imshow` to
235
+ tune the level of compression. 0 means no compression (not recommended).
195
236
196
237
Returns
197
238
-------
@@ -217,7 +258,7 @@ def imshow(
217
258
labels = labels .copy ()
218
259
# ----- Define x and y, set labels if img is an xarray -------------------
219
260
if xarray_imported and isinstance (img , xarray .DataArray ):
220
- if use_binary_string :
261
+ if binary_string :
221
262
raise ValueError (
222
263
"It is not possible to use binary image strings for xarrays."
223
264
"Please pass your data as a numpy array instead using"
@@ -262,24 +303,24 @@ def imshow(
262
303
if aspect is None :
263
304
aspect = "equal"
264
305
265
- # Set the value of use_binary_string
306
+ # Set the value of binary_string
266
307
if isinstance (img , pd .DataFrame ):
267
- if use_binary_string :
308
+ if binary_string :
268
309
raise ValueError ("Binary strings cannot be used with pandas arrays" )
269
310
has_nans = True
270
311
else :
271
312
has_nans = np .any (np .isnan (img ))
272
- if has_nans and use_binary_string :
313
+ if has_nans and binary_string :
273
314
raise ValueError (
274
315
"Binary strings cannot be used with arrays containing NaNs"
275
316
)
276
317
277
318
# --------------- Starting from here img is always a numpy array --------
278
319
img = np .asanyarray (img )
279
320
280
- # Default behaviour of use_binary_string : True for RGB images, False for 2D
281
- if use_binary_string is None :
282
- use_binary_string = img .ndim >= 3 and not has_nans
321
+ # Default behaviour of binary_string : True for RGB images, False for 2D
322
+ if binary_string is None :
323
+ binary_string = img .ndim >= 3 and not has_nans
283
324
284
325
# Cast bools to uint8 (also one byte)
285
326
if img .dtype == np .bool :
@@ -292,18 +333,18 @@ def imshow(
292
333
contrast_rescaling = "minmax" if img .ndim == 2 else "infer"
293
334
294
335
if contrast_rescaling == "minmax" :
295
- if (zmin is not None or use_binary_string ) and zmax is None :
336
+ if (zmin is not None or binary_string ) and zmax is None :
296
337
zmax = img .max ()
297
- if (zmax is not None or use_binary_string ) and zmin is None :
338
+ if (zmax is not None or binary_string ) and zmin is None :
298
339
zmin = img .min ()
299
340
else :
300
- if zmax is None and img .dtype is not np .uint8 :
341
+ if zmax is None and ( img .dtype is not np .uint8 or img . ndim == 2 ) :
301
342
zmax = _infer_zmax_from_type (img )
302
343
if zmin is None :
303
344
zmin = 0
304
345
305
- # For 2d data, use Heatmap trace, unless use_binary_string is True
306
- if img .ndim == 2 and not use_binary_string :
346
+ # For 2d data, use Heatmap trace, unless binary_string is True
347
+ if img .ndim == 2 and not binary_string :
307
348
if y is not None and img .shape [0 ] != len (y ):
308
349
raise ValueError (
309
350
"The length of the y vector must match the length of the first "
@@ -333,13 +374,9 @@ def imshow(
333
374
layout ["coloraxis1" ]["colorbar" ] = dict (title_text = labels ["color" ])
334
375
335
376
# For 2D+RGB data, use Image trace
336
- elif (
337
- img .ndim == 3
338
- and img .shape [- 1 ] in [3 , 4 ]
339
- or (img .ndim == 2 and use_binary_string )
340
- ):
377
+ elif img .ndim == 3 and img .shape [- 1 ] in [3 , 4 ] or (img .ndim == 2 and binary_string ):
341
378
zmin , zmax = _vectorize_zvalue (zmin ), _vectorize_zvalue (zmax )
342
- if use_binary_string :
379
+ if binary_string :
343
380
if img .ndim == 2 :
344
381
img_rescaled = rescale_intensity (
345
382
img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
@@ -357,7 +394,11 @@ def imshow(
357
394
)
358
395
if origin == "lower" :
359
396
img_rescaled = img_rescaled [::- 1 ]
360
- img_str = _array_to_b64str (img_rescaled )
397
+ img_str = _array_to_b64str (
398
+ img_rescaled ,
399
+ backend = binary_backend ,
400
+ compression = binary_compression_level ,
401
+ )
361
402
trace = go .Image (source = img_str )
362
403
else :
363
404
trace = go .Image (z = img , zmin = zmin , zmax = zmax )
@@ -380,8 +421,13 @@ def imshow(
380
421
layout_patch ["margin" ] = {"t" : 60 }
381
422
fig = go .Figure (data = trace , layout = layout )
382
423
fig .update_layout (layout_patch )
383
- # does not work yet, Antoine working on it
384
- hover_name = "z" if not use_binary_string else "colorLabel"
424
+ # Hover name, z or color
425
+ if trace ["type" ] == "heatmap" :
426
+ hover_name = "z"
427
+ elif img .ndim == 2 :
428
+ hover_name = "color[0]"
429
+ else :
430
+ hover_name = "color"
385
431
fig .update_traces (
386
432
hovertemplate = "%s: %%{x}<br>%s: %%{y}<br>%s: %%{%s}<extra></extra>"
387
433
% (
0 commit comments