Skip to content

Commit 2b7f2e7

Browse files
authored
Merge pull request #72 from predict-idlab/compose_figs
🗳️ Compose figs
2 parents 52607dc + 9b95694 commit 2b7f2e7

File tree

10 files changed

+1850
-79
lines changed

10 files changed

+1850
-79
lines changed

plotly_resampler/figure_resampler/figure_resampler.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,20 @@
1717
import plotly.graph_objects as go
1818
from dash import Dash
1919
from jupyter_dash import JupyterDash
20+
from plotly.basedatatypes import BaseFigure
2021
from trace_updater import TraceUpdater
2122

22-
from .figure_resampler_interface import AbstractFigureAggregator
23-
from .utils import is_figure
2423
from ..aggregation import AbstractSeriesAggregator, EfficientLTTB
24+
from .figure_resampler_interface import AbstractFigureAggregator
25+
from .utils import is_figure, is_fr
2526

2627

2728
class FigureResampler(AbstractFigureAggregator, go.Figure):
2829
"""Data aggregation functionality for ``go.Figures``."""
2930

3031
def __init__(
3132
self,
32-
figure: go.Figure | dict = None,
33+
figure: BaseFigure | dict = None,
3334
convert_existing_traces: bool = True,
3435
default_n_shown_samples: int = 1000,
3536
default_downsampler: AbstractSeriesAggregator = EfficientLTTB(),
@@ -40,14 +41,27 @@ def __init__(
4041
show_mean_aggregation_size: bool = True,
4142
verbose: bool = False,
4243
):
43-
if not is_figure(figure): # TODO: does this make sense?
44-
figure = go.Figure(figure)
45-
elif isinstance(figure, FigureResampler):
46-
print("passing") # TODO make composable
47-
pass
44+
# Parse the figure input before calling `super`
45+
if is_figure(figure) and not is_fr(figure): # go.Figure
46+
# Base case, the figure does not need to be adjusted
47+
f = figure
48+
else:
49+
# Create a new figure object and make sure that the trace uid will not get
50+
# adjusted when they are added.
51+
f = go.Figure()
52+
f._data_validator.set_uid = False
53+
54+
if isinstance(figure, BaseFigure): # go.FigureWidget or AbstractFigureAggregator
55+
# A base figure object, we first copy the layout and grid ref
56+
f.layout = figure.layout
57+
f._grid_ref = figure._grid_ref
58+
f.add_traces(figure.data)
59+
elif isinstance(figure, (dict, list)):
60+
# A single trace dict or a list of traces
61+
f.add_traces(figure)
4862

4963
super().__init__(
50-
figure,
64+
f,
5165
convert_existing_traces,
5266
default_n_shown_samples,
5367
default_downsampler,
@@ -56,6 +70,23 @@ def __init__(
5670
verbose,
5771
)
5872

73+
if isinstance(figure, AbstractFigureAggregator):
74+
# Copy the `_hf_data` if the previous figure was an AbstractFigureAggregator
75+
# and adjust the default `max_n_samples` and `downsampler`
76+
self._hf_data.update(
77+
self._copy_hf_data(figure._hf_data, adjust_default_values=True)
78+
)
79+
80+
# Note: This hack ensures that this figure object initially uses
81+
# data of the whole view. More concretely; we create a dict
82+
# serialization figure and adjust the hf-traces to the whole view
83+
# with the check-update method (by passing no range / filter args)
84+
with self.batch_update():
85+
graph_dict: dict = self._get_current_graph()
86+
update_indices = self._check_update_figure_dict(graph_dict)
87+
for idx in update_indices:
88+
self.data[idx].update(graph_dict["data"][idx])
89+
5990
# The FigureResampler needs a dash app
6091
self._app: JupyterDash | Dash | None = None
6192
self._port: int | None = None

plotly_resampler/figure_resampler/figure_resampler_interface.py

Lines changed: 118 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,29 @@ def __init__(
105105
f_._grid_ref = figure._grid_ref
106106
super().__init__(f_)
107107

108-
for trace in figure.data:
109-
self.add_trace(trace)
108+
# make sure that the UIDs of these traces do not get adjusted
109+
self._data_validator.set_uid = False
110+
self.add_traces(figure.data)
110111
else:
111112
super().__init__(figure)
112-
self._data_validator.set_uid = False
113+
self._data_validator.set_uid = False
114+
115+
# A list of all xaxis and yaxis string names
116+
# e.g., "xaxis", "xaxis2", "xaxis3", .... for _xaxis_list
117+
self._xaxis_list = self._re_matches(re.compile("xaxis\d*"), self._layout.keys())
118+
self._yaxis_list = self._re_matches(re.compile("yaxis\d*"), self._layout.keys())
119+
# edge case: an empty `go.Figure()` does not yet contain axes keys
120+
if not len(self._xaxis_list):
121+
self._xaxis_list = ["xaxis"]
122+
self._yaxis_list = ["yaxis"]
123+
124+
# Make sure to reset the layout its range
125+
self.update_layout(
126+
{
127+
axis: {"autorange": True, "range": None}
128+
for axis in self._xaxis_list + self._yaxis_list
129+
}
130+
)
113131

114132
def _print(self, *values):
115133
"""Helper method for printing if ``verbose`` is set to True."""
@@ -650,19 +668,24 @@ def _parse_get_trace_props(
650668
return _hf_data_container(hf_x, hf_y, hf_text, hf_hovertext)
651669

652670
def _construct_hf_data_dict(
653-
self, dc, trace, downsampler, max_n_samples: int, offset=0
671+
self,
672+
dc,
673+
trace,
674+
downsampler: AbstractSeriesAggregator | None,
675+
max_n_samples: int | None,
676+
offset=0,
654677
) -> dict:
655-
"""Create the `hf_data` dict item which will be put in the `_hf_data` property.
678+
"""Create the `hf_data` dict which will be put in the `_hf_data` property.
656679
657680
Parameters
658681
----------
659682
dc : _hf_data_container
660683
The hf_data container, withholding the parsed hf-data
661684
trace : BaseTraceType
662685
The trace.
663-
downsampler : AbstractSeriesAggregator
686+
downsampler : AbstractSeriesAggregator | None
664687
The downsampler which will be used.
665-
max_n_samples : int
688+
max_n_samples : int | None
666689
The max number of output samples.
667690
668691
Returns
@@ -672,7 +695,6 @@ def _construct_hf_data_dict(
672695
"""
673696
# We will re-create this each time as hf_x and hf_y withholds
674697
# high-frequency data
675-
# index = pd.Index(hf_x, copy=False, name="timestamp")
676698
hf_series = self._to_hf_series(x=dc.x, y=dc.y)
677699

678700
# Checking this now avoids less interpretable `KeyError` when resampling
@@ -689,13 +711,25 @@ def _construct_hf_data_dict(
689711
# & (3) store a hf_data entry for the corresponding trace,
690712
# identified by its UUID
691713
axis_type = "date" if isinstance(dc.x, pd.DatetimeIndex) else "linear"
692-
d = self._global_downsampler if downsampler is None else downsampler
714+
715+
default_n_samples = False
716+
if max_n_samples is None:
717+
default_n_samples = True
718+
max_n_samples = self._global_n_shown_samples
719+
720+
default_downsampler = False
721+
if downsampler is None:
722+
default_downsampler = True
723+
downsampler = self._global_downsampler
724+
693725
return {
694726
"max_n_samples": max_n_samples,
727+
"default_n_samples": default_n_samples,
695728
"x": dc.x,
696729
"y": dc.y,
697730
"axis_type": axis_type,
698-
"downsampler": d,
731+
"downsampler": downsampler,
732+
"default_downsampler": default_downsampler,
699733
"text": dc.text,
700734
"hovertext": dc.hovertext,
701735
}
@@ -808,30 +842,35 @@ def add_trace(
808842
also storing the low-frequency series in the back-end.
809843
810844
"""
811-
if max_n_samples is None:
812-
max_n_samples = self._global_n_shown_samples
845+
# to comply with the plotly data input acceptance behavior
846+
if isinstance(trace, (list, tuple)):
847+
raise ValueError("Trace must be either a dict or a BaseTraceType")
813848

814-
# First add an UUID, as each (even the non-hf_data traces), must contain this
815-
# key for comparison
816-
uuid = str(uuid4())
849+
max_out_s = (
850+
self._global_n_shown_samples if max_n_samples is None else max_n_samples
851+
)
817852

818853
# Validate the trace and convert to a trace object
819854
if not isinstance(trace, BaseTraceType):
820855
trace = self._data_validator.validate_coerce(trace)[0]
821-
trace.uid = uuid
856+
857+
# First add an UUID, as each (even the non-hf_data traces), must contain this
858+
# key for comparison. If the trace already has an UUID, we will keep it.
859+
uuid_str = str(uuid4()) if trace.uid is None else trace.uid
860+
trace.uid = uuid_str
822861

823862
dc = self._parse_get_trace_props(trace, hf_x, hf_y, hf_text, hf_hovertext)
824863

825864
n_samples = len(dc.x)
826865
# These traces will determine the autoscale RANGE!
827866
# -> so also store when `limit_to_view` is set.
828867
if trace["type"].lower() in self._high_frequency_traces:
829-
if n_samples > max_n_samples or limit_to_view:
868+
if n_samples > max_out_s or limit_to_view:
830869
self._print(
831-
f"\t[i] DOWNSAMPLE {trace['name']}\t{n_samples}->{max_n_samples}"
870+
f"\t[i] DOWNSAMPLE {trace['name']}\t{n_samples}->{max_out_s}"
832871
)
833872

834-
self._hf_data[uuid] = self._construct_hf_data_dict(
873+
self._hf_data[uuid_str] = self._construct_hf_data_dict(
835874
dc,
836875
trace=trace,
837876
downsampler=downsampler,
@@ -867,7 +906,7 @@ def add_trace(
867906

868907
def add_traces(
869908
self,
870-
data: List[BaseTraceType | dict],
909+
data: List[BaseTraceType | dict] | BaseTraceType | Dict,
871910
max_n_samples: None | List[int] | int = None,
872911
downsamplers: None
873912
| List[AbstractSeriesAggregator]
@@ -877,6 +916,11 @@ def add_traces(
877916
):
878917
"""Add traces to the figure
879918
919+
.. note::
920+
make sure to look at the :func:`add_trace` function for more info about
921+
**speed optimization**, and dealing with not ``high-frequency`` data, but
922+
still want to resample / limit the data to the front-end view.
923+
880924
Parameters
881925
----------
882926
data : List[BaseTraceType | dict]
@@ -917,6 +961,11 @@ def add_traces(
917961
`Figure.add_traces <https://plotly.com/python-api-reference/generated/plotly.graph_objects.Figure.html#plotly.graph_objects.Figure.add_traces>`_ docs.
918962
919963
"""
964+
# note: Plotly's add_traces also allows non-list-like input, e.g. a scatter
965+
# object; the code below is an exact copy of their internally applied parsing
966+
if not isinstance(data, (list, tuple)):
967+
data = [data]
968+
920969
# Convert each trace into a trace object
921970
data = [
922971
self._data_validator.validate_coerce(trace)[0]
@@ -925,6 +974,12 @@ def add_traces(
925974
for trace in data
926975
]
927976

977+
# First add an UUID, as each (even the non-hf_data traces), must contain this
978+
# key for comparison. If the trace already has an UUID, we will keep it.
979+
for trace in data:
980+
uuid_str = str(uuid4()) if trace.uid is None else trace.uid
981+
trace.uid = uuid_str
982+
928983
# Convert the data properties
929984
if isinstance(max_n_samples, (int, np.integer)) or max_n_samples is None:
930985
max_n_samples = [max_n_samples] * len(data)
@@ -942,25 +997,24 @@ def add_traces(
942997
):
943998
continue
944999

945-
max_out = self._global_n_shown_samples if max_out is None else max_out
946-
if len(trace["y"]) <= max_out and not limit_to_view:
1000+
max_out_s = self._global_n_shown_samples if max_out is None else max_out
1001+
if not limit_to_view and (trace.y is None or len(trace.y) <= max_out_s):
9471002
continue
9481003

949-
d = self._global_downsampler if downsampler is None else downsampler
950-
951-
uuid_str = str(uuid4())
952-
trace["uid"] = uuid_str
9531004
dc = self._parse_get_trace_props(trace)
954-
self._hf_data[uuid_str] = self._construct_hf_data_dict(
955-
dc, trace=trace, downsampler=d, max_n_samples=max_out, offset=i
1005+
self._hf_data[trace.uid] = self._construct_hf_data_dict(
1006+
dc,
1007+
trace=trace,
1008+
downsampler=downsampler,
1009+
max_n_samples=max_out,
1010+
offset=i,
9561011
)
9571012

9581013
trace = trace._props # convert the trace into a dict
9591014
trace = {k: trace[k] for k in set(trace.keys()).difference(set(dc._fields))}
9601015

9611016
trace = self._check_update_trace_data(trace)
9621017
assert trace is not None
963-
9641018
data[i] = trace
9651019

9661020
super(self._figure_class, self).add_traces(data, **traces_kwargs)
@@ -973,6 +1027,41 @@ def _clear_figure(self):
9731027
self._layout = {}
9741028
self.layout = {}
9751029

1030+
def _copy_hf_data(self, hf_data: dict, adjust_default_values: bool = False) -> dict:
    """Create a shallow, per-trace copy of a hf_data dict.

    A new outer dict and a new property dict per trace are created; the
    property values themselves are shared with ``hf_data`` (i.e., this is
    not a deep copy).

    Parameters
    ----------
    hf_data : dict
        The hf_data dict, having the trace 'uid' as key and the
        hf-data, together with its aggregation properties as dict-values
    adjust_default_values: bool
        Whether the default values (of the downsampler, max # shown samples) will
        be adjusted according to the values of this object, by default False

    Returns
    -------
    dict
        The copied (& default values adjusted) output dict.

    """
    # `dict(props)` creates a new key -> value mapping per trace without
    # duplicating the values it refers to.
    hf_data_cp = {uid: dict(props) for uid, props in hf_data.items()}

    if adjust_default_values:
        for hf_props in hf_data_cp.values():
            # Only entries flagged as "default" are overwritten; entries
            # without the flag are left untouched.
            if hf_props.get("default_downsampler", False):
                hf_props["downsampler"] = self._global_downsampler
            if hf_props.get("default_n_samples", False):
                hf_props["max_n_samples"] = self._global_n_shown_samples

    return hf_data_cp
1064+
9761065
def replace(self, figure: go.Figure, convert_existing_traces: bool = True):
9771066
"""Replace the current figure layout with the passed figure object.
9781067

0 commit comments

Comments
 (0)