Skip to content

Commit c262153

Browse files
committed
🧹 only parse f16 & update tests
1 parent 815803f commit c262153

File tree

2 files changed

+14
-25
lines changed

2 files changed

+14
-25
lines changed

plotly_resampler/figure_resampler/figure_resampler_interface.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from uuid import uuid4
1919
from collections import namedtuple
2020

21-
import orjson
2221
import dash
2322
import numpy as np
2423
import pandas as pd
@@ -293,6 +292,7 @@ def _check_update_trace_data(
293292
hf_series, hf_trace_data["max_n_samples"]
294293
)
295294
# Also parse the data types to an orjson compatible format
295+
# Note this can be removed once orjson supports f16
296296
trace["x"] = self._parse_dtype_orjson(s_res.index)
297297
trace["y"] = self._parse_dtype_orjson(s_res.values)
298298
# todo -> first draft & not MP safe
@@ -702,15 +702,6 @@ def _parse_get_trace_props(
702702
except ValueError:
703703
hf_y = hf_y.astype("str")
704704

705-
msg = (
706-
"Plotly-Resampler its aggregator functions do not support the float128"
707-
+ " dtype, so please consider casting your data to float64\n."
708-
+ " If you have an eligible usecase where float128 still is necessary,"
709-
+ " please consider making an issue on GitHub."
710-
)
711-
assert hf_x.dtype != np.float128, msg
712-
assert hf_y.dtype != np.float128, msg
713-
714705
assert len(hf_x) == len(hf_y), "x and y have different length!"
715706
else:
716707
self._print(f"trace {trace['type']} is not a high-frequency trace")
@@ -1296,16 +1287,9 @@ def _parse_dtype_orjson(series: np.ndarray) -> np.ndarray:
12961287
# NOTE:
12971288
# * float16 and float128 aren't supported with latest orjson versions (3.8.1)
12981289
# * this method assumes that the it will not get a float128 series
1290+
# -> this method can be removed if orjson supports float16
12991291
if series.dtype in [np.float16]:
13001292
return series.astype(np.float32)
1301-
1302-
# orjson < 3.8.0 encoding cannot encode with int16 & uint16 dtype
1303-
elif series.dtype in [np.int16, np.uint16]:
1304-
major_v, minor_v = list(map(int, orjson.__version__.split(".")))[:2]
1305-
if major_v < 3 or major_v == 3 and minor_v < 8:
1306-
if series.dtype == np.uint16:
1307-
return series.astype("uint32")
1308-
return series.astype(np.int32)
13091293
return series
13101294

13111295
@staticmethod

tests/test_figure_resampler.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def test_add_trace_not_resampling(float_series):
102102
)
103103

104104
def test_various_dtypes(float_series):
105+
# List of dtypes supported by orjson >= 3.8
105106
valid_dtype_list = [
106107
np.bool_,
107108
# ---- uints
@@ -115,29 +116,33 @@ def test_various_dtypes(float_series):
115116
np.int32,
116117
np.int64,
117118
# -------- floats
118-
np.float16,
119+
np.float16, # currently not supported by orjson
119120
np.float32,
120121
np.float64,
121122
]
122123
for dtype in valid_dtype_list:
123124
fig = FigureResampler(go.Figure(), default_n_shown_samples=1000)
125+
# nb. datapoints > default_n_shown_samples
124126
fig.add_trace(
125127
go.Scatter(name="float_series"),
126128
hf_x=float_series.index,
127129
hf_y=float_series.astype(dtype),
128-
limit_to_view=True,
129130
)
131+
fig.full_figure_for_development()
130132

131-
invalid_dtype_list = [ np.float128 ]
133+
# List of dtypes not supported by orjson >= 3.8
134+
invalid_dtype_list = [ np.float16 ]
132135
for invalid_dtype in invalid_dtype_list:
133136
fig = FigureResampler(go.Figure(), default_n_shown_samples=1000)
134-
with pytest.raises(AssertionError):
137+
# nb. datapoints < default_n_shown_samples
138+
with pytest.raises(TypeError):
139+
# if this test fails -> orjson supports f16 => remove casting frome code
135140
fig.add_trace(
136141
go.Scatter(name="float_series"),
137-
hf_x=float_series.index,
138-
hf_y=float_series.astype(invalid_dtype),
139-
limit_to_view=True,
142+
hf_x=float_series.index[:500],
143+
hf_y=float_series.astype(invalid_dtype)[:500],
140144
)
145+
fig.full_figure_for_development()
141146

142147
def test_max_n_samples(float_series):
143148
s = float_series[:5000]

0 commit comments

Comments
 (0)