|
1 | 1 | import plotly.graph_objs as go
|
2 | 2 | from _plotly_utils.basevalidators import ColorscaleValidator
|
3 | 3 | from ._core import apply_default_cascade
|
4 |
| -import numpy as np |
5 | 4 | from PIL import Image
|
6 | 5 | from io import BytesIO
|
7 | 6 | import base64
|
8 | 7 | from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
|
9 | 8 | import pandas as pd
|
10 |
| - |
| 9 | +import png |
| 10 | +import numpy as np |
11 | 11 | try:
|
12 | 12 | import xarray
|
13 | 13 |
|
|
18 | 18 | _float_types = []
|
19 | 19 |
|
20 | 20 |
|
21 |
| -def _array_to_b64str(img, ext="png"): |
22 |
| - pil_img = Image.fromarray(img) |
23 |
| - if ext == "jpg": |
24 |
| - ext = "jpeg" |
25 |
| - buff = BytesIO() |
26 |
| - pil_img.save(buff, format=ext) |
27 |
| - if ext == "png": |
28 |
| - prefix = b"data:image/png;base64," |
29 |
| - elif ext == "jpeg": |
30 |
| - prefix = b"data:image/jpeg;base64," |
| 21 | +def _array_to_b64str(img, backend='pil', compression=4): |
| 22 | + if img.ndim == 2: |
| 23 | + mode = 'L' |
| 24 | + elif img.ndim == 3 and img.shape[-1] == 3: |
| 25 | + mode = 'RGB' |
| 26 | + elif img.ndim == 3 and img.shape[-1] == 4: |
| 27 | + mode = 'RGBA' |
31 | 28 | else:
|
32 | 29 | raise ValueError(
|
33 |
| - "accepted image formats are 'png' and 'jpeg' but %s was passed" % format |
34 |
| - ) |
35 |
| - image_string = (prefix + base64.b64encode(buff.getvalue())).decode("utf-8") |
36 |
| - return image_string |
| 30 | + "Invalid image shape" |
| 31 | + ) |
| 32 | + if backend=='png': |
| 33 | + ndim = img.ndim |
| 34 | + sh = img.shape |
| 35 | + if ndim == 3: |
| 36 | + img = img.reshape((sh[0], sh[1] * sh[2])) |
| 37 | + w = png.Writer(sh[1], sh[0], greyscale=(ndim == 2), compression=compression) |
| 38 | + img_png = png.from_array(img, mode=mode) |
| 39 | + prefix = "data:image/png;base64," |
| 40 | + with BytesIO() as stream: |
| 41 | + w.write(stream, img_png.rows) |
| 42 | + base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8") |
| 43 | + else: |
| 44 | + pil_img = Image.fromarray(img) |
| 45 | + prefix = "data:image/png;base64," |
| 46 | + with BytesIO() as stream: |
| 47 | + pil_img.save(stream, format='png', compress_level=compression) |
| 48 | + base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8") |
| 49 | + return base64_string |
37 | 50 |
|
38 | 51 |
|
39 | 52 | def _vectorize_zvalue(z):
|
|
0 commit comments