Skip to content

Commit 066564e

Browse files
committed
Update logic to be recursive
1 parent 61e9178 commit 066564e

File tree

2 files changed

+107
-94
lines changed

2 files changed

+107
-94
lines changed

packages/python/plotly/_plotly_utils/basevalidators.py

Lines changed: 0 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -51,83 +51,6 @@ def to_scalar_or_list(v):
5151
return v
5252

5353

54-
plotlyjsShortTypes = {
55-
"int8": "i1",
56-
"uint8": "u1",
57-
"int16": "i2",
58-
"uint16": "u2",
59-
"int32": "i4",
60-
"uint32": "u4",
61-
"float32": "f4",
62-
"float64": "f8",
63-
}
64-
65-
int8min = -128
66-
int8max = 127
67-
int16min = -32768
68-
int16max = 32767
69-
int32min = -2147483648
70-
int32max = 2147483647
71-
72-
uint8max = 255
73-
uint16max = 65535
74-
uint32max = 4294967295
75-
76-
77-
def to_typed_array_spec(v):
78-
"""
79-
Convert numpy array to plotly.js typed array spec
80-
If not possible return the original value
81-
"""
82-
v = copy_to_readonly_numpy_array(v)
83-
84-
np = get_module("numpy", should_load=False)
85-
if not isinstance(v, np.ndarray):
86-
return v
87-
88-
dtype = str(v.dtype)
89-
90-
# convert default Big Ints until we could support them in plotly.js
91-
if dtype == "int64":
92-
max = v.max()
93-
min = v.min()
94-
if max <= int8max and min >= int8min:
95-
v = v.astype("int8")
96-
elif max <= int16max and min >= int16min:
97-
v = v.astype("int16")
98-
elif max <= int32max and min >= int32min:
99-
v = v.astype("int32")
100-
else:
101-
return v
102-
103-
elif dtype == "uint64":
104-
max = v.max()
105-
min = v.min()
106-
if max <= uint8max and min >= 0:
107-
v = v.astype("uint8")
108-
elif max <= uint16max and min >= 0:
109-
v = v.astype("uint16")
110-
elif max <= uint32max and min >= 0:
111-
v = v.astype("uint32")
112-
else:
113-
return v
114-
115-
dtype = str(v.dtype)
116-
117-
if dtype in plotlyjsShortTypes:
118-
arrObj = {
119-
"dtype": plotlyjsShortTypes[dtype],
120-
"bdata": base64.b64encode(v).decode("ascii"),
121-
}
122-
123-
if v.ndim > 1:
124-
arrObj["shape"] = str(v.shape)[1:-1]
125-
126-
return arrObj
127-
128-
return v
129-
130-
13154
def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False):
13255
"""
13356
Convert an array-like value into a read-only numpy array
@@ -292,15 +215,6 @@ def is_typed_array_spec(v):
292215
return isinstance(v, dict) and "bdata" in v and "dtype" in v
293216

294217

295-
def has_skipped_key(all_parent_keys):
296-
"""
297-
Return whether any keys in the parent hierarchy are in the list of keys that
298-
are skipped for conversion to the typed array spec
299-
"""
300-
skipped_keys = ["geojson", "layer", "range"]
301-
return any(skipped_key in all_parent_keys for skipped_key in skipped_keys)
302-
303-
304218
def is_none_or_typed_array_spec(v):
305219
return v is None or is_typed_array_spec(v)
306220

@@ -500,8 +414,6 @@ def description(self):
500414
def validate_coerce(self, v):
501415
if is_none_or_typed_array_spec(v):
502416
pass
503-
elif has_skipped_key(self.parent_name):
504-
v = to_scalar_or_list(v)
505417
elif is_homogeneous_array(v):
506418
v = copy_to_readonly_numpy_array(v)
507419
elif is_simple_array(v):

packages/python/plotly/plotly/io/_utils.py

Lines changed: 107 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,113 @@
1-
from _plotly_utils.basevalidators import is_homogeneous_array, to_typed_array_spec
1+
import base64
2+
from _plotly_utils.basevalidators import (
3+
copy_to_readonly_numpy_array,
4+
is_homogeneous_array,
5+
to_typed_array_spec,
6+
)
7+
from packages.python.plotly._plotly_utils.optional_imports import get_module
28
import plotly
39
import plotly.graph_objs as go
410
from plotly.offline import get_plotlyjs_version
511

12+
int8min = -128
13+
int8max = 127
14+
int16min = -32768
15+
int16max = 32767
16+
int32min = -2147483648
17+
int32max = 2147483647
18+
19+
uint8max = 255
20+
uint16max = 65535
21+
uint32max = 4294967295
22+
23+
plotlyjsShortTypes = {
24+
"int8": "i1",
25+
"uint8": "u1",
26+
"int16": "i2",
27+
"uint16": "u2",
28+
"int32": "i4",
29+
"uint32": "u4",
30+
"float32": "f4",
31+
"float64": "f8",
32+
}
33+
34+
35+
def to_typed_array_spec(v):
36+
"""
37+
Convert numpy array to plotly.js typed array spec
38+
If not possible return the original value
39+
"""
40+
v = copy_to_readonly_numpy_array(v)
41+
42+
np = get_module("numpy", should_load=False)
43+
if not isinstance(v, np.ndarray):
44+
return v
45+
46+
dtype = str(v.dtype)
47+
48+
# convert default Big Ints until we could support them in plotly.js
49+
if dtype == "int64":
50+
max = v.max()
51+
min = v.min()
52+
if max <= int8max and min >= int8min:
53+
v = v.astype("int8")
54+
elif max <= int16max and min >= int16min:
55+
v = v.astype("int16")
56+
elif max <= int32max and min >= int32min:
57+
v = v.astype("int32")
58+
else:
59+
return v
60+
61+
elif dtype == "uint64":
62+
max = v.max()
63+
min = v.min()
64+
if max <= uint8max and min >= 0:
65+
v = v.astype("uint8")
66+
elif max <= uint16max and min >= 0:
67+
v = v.astype("uint16")
68+
elif max <= uint32max and min >= 0:
69+
v = v.astype("uint32")
70+
else:
71+
return v
72+
73+
dtype = str(v.dtype)
74+
75+
if dtype in plotlyjsShortTypes:
76+
arrObj = {
77+
"dtype": plotlyjsShortTypes[dtype],
78+
"bdata": base64.b64encode(v).decode("ascii"),
79+
}
80+
81+
if v.ndim > 1:
82+
arrObj["shape"] = str(v.shape)[1:-1]
83+
84+
return arrObj
85+
86+
return v
87+
88+
89+
def is_skipped_key(key):
90+
"""
91+
Return whether any keys in the parent hierarchy are in the list of keys that
92+
are skipped for conversion to the typed array spec
93+
"""
94+
skipped_keys = ["geojson", "layer", "range"]
95+
return any(skipped_key in key for skipped_key in skipped_keys)
96+
97+
98+
def convert_to_base64(obj):
99+
if isinstance(obj, dict):
100+
for key, value in obj.items():
101+
if is_skipped_key(key):
102+
continue
103+
elif is_homogeneous_array(value):
104+
obj[key] = to_typed_array_spec(value)
105+
else:
106+
convert_to_base64(value)
107+
elif isinstance(obj, list) or isinstance(obj, tuple):
108+
for i, value in enumerate(obj):
109+
convert_to_base64(value)
110+
6111

7112
def validate_coerce_fig_to_dict(fig, validate):
8113
from plotly.basedatatypes import BaseFigure
@@ -27,11 +132,7 @@ def validate_coerce_fig_to_dict(fig, validate):
27132
)
28133

29134
# Add base64 conversion before sending to the front-end
30-
for trace in fig_dict["data"]:
31-
for key, value in trace.items():
32-
if is_homogeneous_array(value):
33-
print("to typed array: key:", key, "value:", value)
34-
trace[key] = to_typed_array_spec(value)
135+
convert_to_base64(fig_dict)
35136

36137
return fig_dict
37138

0 commit comments

Comments
 (0)