Skip to content

Commit d3d88b8

Browse files
committed
different backends for b64 creation
1 parent 81ce9f6 commit d3d88b8

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import plotly.graph_objs as go
22
from _plotly_utils.basevalidators import ColorscaleValidator
33
from ._core import apply_default_cascade
4-
import numpy as np
54
from PIL import Image
65
from io import BytesIO
76
import base64
87
from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types
98
import pandas as pd
10-
9+
import png
10+
import numpy as np
1111
try:
1212
import xarray
1313

@@ -18,22 +18,35 @@
1818
_float_types = []
1919

2020

21-
def _array_to_b64str(img, ext="png"):
22-
pil_img = Image.fromarray(img)
23-
if ext == "jpg":
24-
ext = "jpeg"
25-
buff = BytesIO()
26-
pil_img.save(buff, format=ext)
27-
if ext == "png":
28-
prefix = b"data:image/png;base64,"
29-
elif ext == "jpeg":
30-
prefix = b"data:image/jpeg;base64,"
21+
def _array_to_b64str(img, backend='pil', compression=4):
22+
if img.ndim == 2:
23+
mode = 'L'
24+
elif img.ndim == 3 and img.shape[-1] == 3:
25+
mode = 'RGB'
26+
elif img.ndim == 3 and img.shape[-1] == 4:
27+
mode = 'RGBA'
3128
else:
3229
raise ValueError(
33-
"accepted image formats are 'png' and 'jpeg' but %s was passed" % format
34-
)
35-
image_string = (prefix + base64.b64encode(buff.getvalue())).decode("utf-8")
36-
return image_string
30+
"Invalid image shape"
31+
)
32+
if backend=='png':
33+
ndim = img.ndim
34+
sh = img.shape
35+
if ndim == 3:
36+
img = img.reshape((sh[0], sh[1] * sh[2]))
37+
w = png.Writer(sh[1], sh[0], greyscale=(ndim == 2), compression=compression)
38+
img_png = png.from_array(img, mode=mode)
39+
prefix = "data:image/png;base64,"
40+
with BytesIO() as stream:
41+
w.write(stream, img_png.rows)
42+
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
43+
else:
44+
pil_img = Image.fromarray(img)
45+
prefix = "data:image/png;base64,"
46+
with BytesIO() as stream:
47+
pil_img.save(stream, format='png', compress_level=compression)
48+
base64_string = prefix + base64.b64encode(stream.getvalue()).decode("utf-8")
49+
return base64_string
3750

3851

3952
def _vectorize_zvalue(z):

0 commit comments

Comments
 (0)