Skip to content

Commit 01184e8

Browse files
authored
Merge pull request #81 from predict-idlab/convert_traces_kwargs
✨ adding `convert_traces_kwargs`
2 parents 0ea320b + 09314f8 commit 01184e8

File tree

4 files changed

+242
-3
lines changed

4 files changed

+242
-3
lines changed

plotly_resampler/figure_resampler/figure_resampler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
"",
4040
),
4141
show_mean_aggregation_size: bool = True,
42+
convert_traces_kwargs: dict | None = None,
4243
verbose: bool = False,
4344
):
4445
# Parse the figure input before calling `super`
@@ -67,6 +68,7 @@ def __init__(
6768
default_downsampler,
6869
resampled_trace_prefix_suffix,
6970
show_mean_aggregation_size,
71+
convert_traces_kwargs,
7072
verbose,
7173
)
7274

plotly_resampler/figure_resampler/figure_resampler_interface.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
"",
4949
),
5050
show_mean_aggregation_size: bool = True,
51+
convert_traces_kwargs: dict | None = None,
5152
verbose: bool = False,
5253
):
5354
"""Instantiate a resampling data mirror.
@@ -82,6 +83,12 @@ def __init__(
8283
show_mean_aggregation_size: bool, optional
8384
Whether the mean aggregation bin size will be added as a suffix to the trace
8485
its legend-name, by default True.
86+
convert_traces_kwargs: dict, optional
87+
A dict of kwargs that will be passed to the :func:`add_traces` method and
88+
will be used to convert the existing traces. \n
89+
.. note::
90+
This argument is only used when the passed ``figure`` contains data and
91+
``convert_existing_traces`` is set to True.
8592
verbose: bool, optional
8693
Whether some verbose messages will be printed or not, by default False.
8794
@@ -109,9 +116,12 @@ def __init__(
109116
f_._grid_ref = figure._grid_ref
110117
super().__init__(f_)
111118

119+
if convert_traces_kwargs is None:
120+
convert_traces_kwargs = {}
121+
112122
# make sure that the UIDs of these traces do not get adjusted
113123
self._data_validator.set_uid = False
114-
self.add_traces(figure.data)
124+
self.add_traces(figure.data, **convert_traces_kwargs)
115125
else:
116126
super().__init__(figure)
117127
self._data_validator.set_uid = False
@@ -432,7 +442,7 @@ def _get_figure_class(constr: type) -> type:
432442
433443
.. Note::
434444
This method will always return a plotly constructor, even when the given
435-
`constr` is decorated (after executing the ``register_plotly_resampler``
445+
`constr` is decorated (after executing the ``register_plotly_resampler``
436446
function).
437447
438448
Parameters
@@ -952,7 +962,7 @@ def add_traces(
952962
953963
.. note::
954964
Make sure to look at the :func:`add_trace` function for more info about
955-
**speed optimization**, and dealing with not ``high-frequency`` data, but
965+
**speed optimization**, and dealing with not ``high-frequency`` data, but
956966
still want to resample / limit the data to the front-end view.
957967
958968
Parameters

plotly_resampler/figure_resampler/figurewidget_resampler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
"",
5050
),
5151
show_mean_aggregation_size: bool = True,
52+
convert_traces_kwargs: dict | None = None,
5253
verbose: bool = False,
5354
):
5455
# Parse the figure input before calling `super`
@@ -71,6 +72,7 @@ def __init__(
7172
default_downsampler,
7273
resampled_trace_prefix_suffix,
7374
show_mean_aggregation_size,
75+
convert_traces_kwargs,
7476
verbose,
7577
)
7678

tests/test_composability.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from random import sample
12
import plotly.graph_objects as go
23
from plotly.subplots import make_subplots
34
from plotly_resampler import FigureResampler, FigureWidgetResampler
@@ -34,6 +35,60 @@ def test_fr_f_scatter_agg(float_series, bool_series, cat_series):
3435
assert trace.uid not in fr_f._hf_data
3536
assert len(trace["y"]) == 10_000
3637

38+
def test_fr_fwr_f_scatter_convert_traces_kwargs(
39+
float_series, bool_series, cat_series
40+
):
41+
base_fig = make_subplots(
42+
rows=2,
43+
cols=2,
44+
specs=[[{}, {}], [{"colspan": 2}, None]],
45+
)
46+
base_fig.add_trace(go.Scatter(y=cat_series), row=1, col=1)
47+
base_fig.add_trace(dict(y=bool_series), row=1, col=2)
48+
base_fig.add_trace(go.Scattergl(y=float_series), row=2, col=1)
49+
50+
fwr = FigureWidgetResampler(base_fig, default_n_shown_samples=10_000)
51+
assert len(fwr.hf_data) == 0
52+
53+
fwr = FigureWidgetResampler(
54+
base_fig,
55+
default_n_shown_samples=10_000,
56+
convert_traces_kwargs=dict(limit_to_views=True),
57+
)
58+
assert len(fwr.hf_data) == 3
59+
60+
n_sample_list = [1010, 1020, 1030]
61+
fwr = FigureWidgetResampler(
62+
base_fig,
63+
default_n_shown_samples=10_000,
64+
convert_traces_kwargs=dict(max_n_samples=n_sample_list),
65+
)
66+
assert len(fwr.hf_data) == 3
67+
for i, n_samples in enumerate(n_sample_list):
68+
assert len(fwr.data[i]["y"]) == n_samples
69+
assert len(fwr.hf_data[i]["y"]) == 10_000
70+
71+
# FigureResampler as wrapped class
72+
fr = FigureResampler(base_fig, default_n_shown_samples=10_000)
73+
assert len(fr.hf_data) == 0
74+
75+
fr = FigureResampler(
76+
base_fig,
77+
default_n_shown_samples=10_000,
78+
convert_traces_kwargs=dict(limit_to_views=True),
79+
)
80+
assert len(fr.hf_data) == 3
81+
82+
fr = FigureWidgetResampler(
83+
base_fig,
84+
default_n_shown_samples=10_000,
85+
convert_traces_kwargs=dict(max_n_samples=n_sample_list),
86+
)
87+
assert len(fr.hf_data) == 3
88+
for i, n_samples in enumerate(n_sample_list):
89+
assert len(fr.data[i]["y"]) == n_samples
90+
assert len(fwr.hf_data[i]["y"]) == 10_000
91+
3792
def test_fwr_f_scatter_agg(float_series, bool_series, cat_series):
3893
base_fig = make_subplots(
3994
rows=2,
@@ -211,6 +266,62 @@ def test_fr_fw_scatter_agg(float_series, bool_series, cat_series):
211266
assert trace.uid not in fr_fw._hf_data
212267
assert len(trace["y"]) == 10_000
213268

269+
def test_fr_fwr_fw_scatter_convert_traces_kwargs(
270+
float_series, bool_series, cat_series
271+
):
272+
base_fig = go.FigureWidget(
273+
make_subplots(
274+
rows=2,
275+
cols=2,
276+
specs=[[{}, {}], [{"colspan": 2}, None]],
277+
)
278+
)
279+
base_fig.add_trace(go.Scatter(y=cat_series), row=1, col=1)
280+
base_fig.add_trace(dict(y=bool_series), row=1, col=2)
281+
base_fig.add_trace(go.Scattergl(y=float_series), row=2, col=1)
282+
283+
fwr = FigureWidgetResampler(base_fig, default_n_shown_samples=10_000)
284+
assert len(fwr.hf_data) == 0
285+
286+
fwr = FigureWidgetResampler(
287+
base_fig,
288+
default_n_shown_samples=10_000,
289+
convert_traces_kwargs=dict(limit_to_views=True),
290+
)
291+
assert len(fwr.hf_data) == 3
292+
293+
n_sample_list = [1010, 1020, 1030]
294+
fwr = FigureWidgetResampler(
295+
base_fig,
296+
default_n_shown_samples=10_000,
297+
convert_traces_kwargs=dict(max_n_samples=n_sample_list),
298+
)
299+
assert len(fwr.hf_data) == 3
300+
for i, n_samples in enumerate(n_sample_list):
301+
assert len(fwr.data[i]["y"]) == n_samples
302+
assert len(fwr.hf_data[i]["y"]) == 10_000
303+
304+
# FigureResampler as wrapped class
305+
fr = FigureResampler(base_fig, default_n_shown_samples=10_000)
306+
assert len(fr.hf_data) == 0
307+
308+
fr = FigureResampler(
309+
base_fig,
310+
default_n_shown_samples=10_000,
311+
convert_traces_kwargs=dict(limit_to_views=True),
312+
)
313+
assert len(fr.hf_data) == 3
314+
315+
fr = FigureWidgetResampler(
316+
base_fig,
317+
default_n_shown_samples=10_000,
318+
convert_traces_kwargs=dict(max_n_samples=n_sample_list),
319+
)
320+
assert len(fr.hf_data) == 3
321+
for i, n_samples in enumerate(n_sample_list):
322+
assert len(fr.data[i]["y"]) == n_samples
323+
assert len(fwr.hf_data[i]["y"]) == 10_000
324+
214325
def test_fwr_fw_scatter_agg(float_series, bool_series, cat_series):
215326
base_fig = go.FigureWidget(
216327
make_subplots(
@@ -403,6 +514,63 @@ def test_fr_fr_scatter_agg(float_series, bool_series, cat_series):
403514
for trace in fr_fr.data:
404515
assert len(trace["y"]) == 10_000
405516

517+
def test_fr_fwr_fr_scatter_convert_traces_kwargs(
518+
float_series, bool_series, cat_series
519+
):
520+
base_fig = FigureResampler(
521+
make_subplots(
522+
rows=2,
523+
cols=2,
524+
specs=[[{}, {}], [{"colspan": 2}, None]],
525+
),
526+
default_n_shown_samples=10_000,
527+
)
528+
base_fig.add_trace(go.Scatter(y=cat_series), row=1, col=1)
529+
base_fig.add_trace(dict(y=bool_series), row=1, col=2)
530+
base_fig.add_trace(go.Scattergl(y=float_series), row=2, col=1)
531+
532+
fwr = FigureWidgetResampler(base_fig, default_n_shown_samples=10_000)
533+
assert len(fwr.hf_data) == 0
534+
535+
fwr = FigureWidgetResampler(
536+
base_fig,
537+
default_n_shown_samples=10_000,
538+
convert_traces_kwargs=dict(limit_to_views=True),
539+
)
540+
assert len(fwr.hf_data) == 3
541+
542+
n_sample_list = [1010, 1020, 1030]
543+
fwr = FigureWidgetResampler(
544+
base_fig,
545+
default_n_shown_samples=10_000,
546+
convert_traces_kwargs=dict(max_n_samples=n_sample_list),
547+
)
548+
assert len(fwr.hf_data) == 3
549+
for i, n_samples in enumerate(n_sample_list):
550+
assert len(fwr.data[i]["y"]) == n_samples
551+
assert len(fwr.hf_data[i]["y"]) == 10_000
552+
553+
# FigureResampler as wrapped class
554+
fr = FigureResampler(base_fig, default_n_shown_samples=10_000)
555+
assert len(fr.hf_data) == 0
556+
557+
fr = FigureResampler(
558+
base_fig,
559+
default_n_shown_samples=10_000,
560+
convert_traces_kwargs=dict(limit_to_views=True),
561+
)
562+
assert len(fr.hf_data) == 3
563+
564+
fr = FigureWidgetResampler(
565+
base_fig,
566+
default_n_shown_samples=10_000,
567+
convert_traces_kwargs=dict(max_n_samples=n_sample_list),
568+
)
569+
assert len(fr.hf_data) == 3
570+
for i, n_samples in enumerate(n_sample_list):
571+
assert len(fr.data[i]["y"]) == n_samples
572+
assert len(fwr.hf_data[i]["y"]) == 10_000
573+
406574
def test_fr_fr_scatter_no_agg_agg(float_series, bool_series, cat_series):
407575
# This initial figure object does not contain any aggregated data as
408576
# default_n_shown samples >= the input data
@@ -780,6 +948,63 @@ def test_fr_fwr_scatter_agg(float_series, bool_series, cat_series):
780948
for trace in fr_fw.data:
781949
assert len(trace["y"]) == 10_000
782950

951+
def test_fr_fwr_fwr_scatter_convert_traces_kwargs(
952+
float_series, bool_series, cat_series
953+
):
954+
base_fig = FigureWidgetResampler(
955+
make_subplots(
956+
rows=2,
957+
cols=2,
958+
specs=[[{}, {}], [{"colspan": 2}, None]],
959+
),
960+
default_n_shown_samples=10_000,
961+
)
962+
base_fig.add_trace(go.Scatter(y=cat_series), row=1, col=1)
963+
base_fig.add_trace(dict(y=bool_series), row=1, col=2)
964+
base_fig.add_trace(go.Scattergl(y=float_series), row=2, col=1)
965+
966+
fwr = FigureWidgetResampler(base_fig, default_n_shown_samples=10_000)
967+
assert len(fwr.hf_data) == 0
968+
969+
fwr = FigureWidgetResampler(
970+
base_fig,
971+
default_n_shown_samples=10_000,
972+
convert_traces_kwargs=dict(limit_to_views=True),
973+
)
974+
assert len(fwr.hf_data) == 3
975+
976+
n_sample_list = [1010, 1020, 1030]
977+
fwr = FigureWidgetResampler(
978+
base_fig,
979+
default_n_shown_samples=10_000,
980+
convert_traces_kwargs=dict(max_n_samples=n_sample_list),
981+
)
982+
assert len(fwr.hf_data) == 3
983+
for i, n_samples in enumerate(n_sample_list):
984+
assert len(fwr.data[i]["y"]) == n_samples
985+
assert len(fwr.hf_data[i]["y"]) == 10_000
986+
987+
# FigureResampler as wrapped class
988+
fr = FigureResampler(base_fig, default_n_shown_samples=10_000)
989+
assert len(fr.hf_data) == 0
990+
991+
fr = FigureResampler(
992+
base_fig,
993+
default_n_shown_samples=10_000,
994+
convert_traces_kwargs=dict(limit_to_views=True),
995+
)
996+
assert len(fr.hf_data) == 3
997+
998+
fr = FigureWidgetResampler(
999+
base_fig,
1000+
default_n_shown_samples=10_000,
1001+
convert_traces_kwargs=dict(max_n_samples=n_sample_list),
1002+
)
1003+
assert len(fr.hf_data) == 3
1004+
for i, n_samples in enumerate(n_sample_list):
1005+
assert len(fr.data[i]["y"]) == n_samples
1006+
assert len(fwr.hf_data[i]["y"]) == 10_000
1007+
7831008
def test_fr_fwr_scatter_no_agg_agg(float_series, bool_series, cat_series):
7841009
# This inital figure object does not contain any aggregated data as
7851010
# default_n_shown samples >= the input data

0 commit comments

Comments
 (0)