Skip to content

Commit 355fbec

Browse files
committed
💪 making orjson serialization more robust, see #118
1 parent 5d34b62 commit 355fbec

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

plotly_resampler/figure_resampler/figure_resampler_interface.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from uuid import uuid4
1919
from collections import namedtuple
2020

21+
import orjson
2122
import dash
2223
import numpy as np
2324
import pandas as pd
@@ -291,7 +292,7 @@ def _check_update_trace_data(
291292
s_res: pd.Series = downsampler.aggregate(
292293
hf_series, hf_trace_data["max_n_samples"]
293294
)
294-
trace["x"] = s_res.index
295+
trace["x"] = self._parse_dtype_orjson(s_res.index)
295296
trace["y"] = s_res.values
296297
# todo -> first draft & not MP safe
297298

@@ -700,9 +701,14 @@ def _parse_get_trace_props(
700701
except ValueError:
701702
hf_y = hf_y.astype("str")
702703

703-
# orjson encoding doesn't like to encode with uint8 & uint16 dtype
704-
if str(hf_y.dtype) in ["uint8", "uint16"]:
705-
hf_y = hf_y.astype("uint32")
704+
msg = (
705+
"Plotly-Resampler its aggregator functions do not support the float128"
706+
+ " dtype, so please consider casting your data to float64\n."
707+
+ " If you have an eligible usecase where float128 still is necessary,"
708+
+ " please consider making an issue on GitHub."
709+
)
710+
assert hf_x.dtype != np.float128, msg
711+
assert hf_y.dtype != np.float128, msg
706712

707713
assert len(hf_x) == len(hf_y), "x and y have different length!"
708714
else:
@@ -1152,8 +1158,7 @@ def replace(self, figure: go.Figure, convert_existing_traces: bool = True):
11521158
)
11531159

11541160
def construct_update_data(
1155-
self,
1156-
relayout_data: dict
1161+
self, relayout_data: dict
11571162
) -> Union[List[dict], dash.no_update]:
11581163
"""Construct the to-be-updated front-end data, based on the layout change.
11591164
@@ -1261,6 +1266,24 @@ def construct_update_data(
12611266
layout_traces_list.append(trace_reduced)
12621267
return layout_traces_list
12631268

1269+
@staticmethod
1270+
def _parse_dtype_orjson(series: np.ndarray) -> np.ndarray:
1271+
"""Verify the orjson compatibility of the series and convert it if needed."""
1272+
# NOTE:
1273+
# * float16 and float128 aren't supported with latest orjson versions (3.8.1)
1274+
# * this method assumes that the it will not get a float128 series
1275+
if series.dtype in [np.float16]:
1276+
return series.astype(np.float32)
1277+
1278+
# orjson < 3.8.0 encoding cannot encode with int16 & uint16 dtype
1279+
elif series.dtype in [np.int16, np.uint16]:
1280+
major_v, minor_v = list(map(int, orjson.__version__.split(".")))[:2]
1281+
if major_v < 3 or major_v == 3 and minor_v < 8:
1282+
if series.dtype == np.uint16:
1283+
return series.astype("uint32")
1284+
return series.astype(np.int32)
1285+
return series
1286+
12641287
@staticmethod
12651288
def _re_matches(regex: re.Pattern, strings: Iterable[str]) -> List[str]:
12661289
"""Returns all the items in ``strings`` which regex.match(es) ``regex``."""

tests/test_figure_resampler.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,43 @@ def test_add_trace_not_resampling(float_series):
101101
hf_hovertext="hovertext",
102102
)
103103

104+
def test_various_dtypes(float_series):
105+
valid_dtype_list = [
106+
np.bool_,
107+
# ---- uints
108+
np.uint8,
109+
np.uint16,
110+
np.uint32,
111+
np.uint64,
112+
# -------- ints
113+
np.int8,
114+
np.int16,
115+
np.int32,
116+
np.int64,
117+
# -------- floats
118+
np.float16,
119+
np.float32,
120+
np.float64,
121+
]
122+
for dtype in valid_dtype_list:
123+
fig = FigureResampler(go.Figure(), default_n_shown_samples=1000)
124+
fig.add_trace(
125+
go.Scatter(name="float_series"),
126+
hf_x=float_series.index,
127+
hf_y=float_series.astype(dtype),
128+
limit_to_view=True,
129+
)
130+
131+
invalid_dtype_list = [ np.float128 ]
132+
for invalid_dtype in invalid_dtype_list:
133+
fig = FigureResampler(go.Figure(), default_n_shown_samples=1000)
134+
with pytest.raises(AssertionError):
135+
fig.add_trace(
136+
go.Scatter(name="float_series"),
137+
hf_x=float_series.index,
138+
hf_y=float_series.astype(invalid_dtype),
139+
limit_to_view=True,
140+
)
104141

105142
def test_add_scatter_trace_no_data():
106143
fig = FigureResampler(default_n_shown_samples=1000)

0 commit comments

Comments
 (0)