Skip to content

Commit 0cee8b7

Browse files
Added plot functionality
1 parent 51e7697 commit 0cee8b7

File tree

8 files changed

+396
-75
lines changed

8 files changed

+396
-75
lines changed

examples/example_calendar_split.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -26,11 +26,11 @@
2626

2727
# Example 1
2828
# ---------
29-
29+
3030
model = CalendarSplit(
3131
data.index,
32-
n_train = 3,
33-
n_test = 1,
32+
n_train=3,
33+
n_test=1,
3434
every='YS',
3535
sample_labels=["IS", "OOS"]
3636
)
@@ -98,15 +98,15 @@
9898

9999
# Example 2
100100
# ---------
101-
101+
102102
model = CalendarSplit(
103103
data.index,
104-
n_train = 4,
105-
n_test = 2,
104+
n_train=4,
105+
n_test=2,
106106
every='QS',
107107
sample_labels=["IS", "OOS"]
108108
)
109-
109+
110110
# Base
111111
print(model.splits_arr)
112112
print(model.splits_arr.dtype)

fold/model_selection/base.py

Lines changed: 104 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,17 @@
11
import abc
2+
from dataclasses import asdict
3+
import typing as tp
24
import numpy as np
35
import pandas as pd
46
from pandas.tseries.offsets import BaseOffset as Offset
5-
import typing as tp
7+
from plotly.basedatatypes import BaseFigure
8+
from plotly.graph_objects import FigureWidget
9+
import plotly.express as px
610

711
from fold.model_selection.config import Config, UpdateConfig
812
from fold.model_selection.duration import Duration
913
from fold.model_selection.coverage import Coverage
14+
from fold.model_selection.plots import Layout, Trace, Heatmap
1015
from fold.tools import (
1116
substitute,
1217
BaseTool,
@@ -119,7 +124,7 @@ def __init__(
119124
self._ndim = 2 if len(self._sample_labels) > 1 else config.ndim
120125
self._index = index
121126
self._splits_arr = splits_arr
122-
127+
123128
super().__init__(
124129
index=self.index,
125130
splits_arr=self.splits_arr,
@@ -128,7 +133,7 @@ def __init__(
128133
split_labels=self.split_labels,
129134
sample_labels=self.sample_labels
130135
)
131-
136+
132137
@property
133138
def split_labels(self) -> pd.Index:
134139
"""
@@ -587,6 +592,98 @@ def _get_range_meta(i, j):
587592

588593
return pd.Series(range_objs, index=keys, dtype=object)
589594

595+
def plot(
    self,
    mask_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
    trace_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
    add_trace_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
    fig: tp.Optional[tp.Any] = None,
    figure_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
    layout_kwargs: tp.Optional[Layout] = None,
) -> BaseFigure:
    """
    Plot one heatmap trace per sample (e.g. training vs test) across splits.

    Parameters
    ----------
    mask_kwargs : tp.Dict[str, tp.Any], optional
        Reserved for keyword arguments of the mask iterator.
        NOTE(review): currently accepted but not forwarded to
        `get_iter_sample_masks`; confirm the intended target before wiring
        it through.
    trace_kwargs : tp.Dict[str, tp.Any], optional
        Keyword arguments passed to the `plotly` `Heatmap` trace. They take
        precedence over the per-sample defaults built here.
    add_trace_kwargs : tp.Dict[str, tp.Any], optional
        Keyword arguments passed to `add_trace`.
    fig : FigureWidget, optional
        Figure to add traces to. A new `FigureWidget` is created when omitted.
    figure_kwargs : tp.Dict[str, tp.Any], optional
        Keyword arguments passed to `FigureWidget` when `fig` is omitted.
    layout_kwargs : Layout, optional
        Layout overrides. Values are merged on top of the `Layout()`
        defaults; a plain mapping is accepted as well.

    Returns
    -------
    BaseFigure
        The figure the traces were added to.

    Examples
    --------
    ```pycon
    >>> from fold import SklearnFold
    >>> import pandas as pd
    >>> from sklearn.model_selection import TimeSeriesSplit

    >>> index = pd.date_range("2010", "2024", freq="W")
    >>> model = SklearnFold(index, TimeSeriesSplit())
    >>> model.plot().show()
    ```

    """
    trace_kwargs = trace_kwargs or {}
    add_trace_kwargs = add_trace_kwargs or {}
    figure_kwargs = figure_kwargs or {}

    # Start from the default layout and overlay whatever the caller supplied.
    # Previously the caller's value was discarded and only defaults were used.
    if layout_kwargs is None:
        layout_overrides = {}
    elif isinstance(layout_kwargs, Layout):
        layout_overrides = asdict(layout_kwargs)
    else:
        layout_overrides = dict(layout_kwargs)
    layout_kwargs = asdict(Layout()) | layout_overrides

    # Figure
    if fig is None:
        fig = FigureWidget(**figure_kwargs)

    # Update layout
    fig.update_layout(**layout_kwargs)

    # Colors: prefer the figure's own colorway, then the template's.
    colorway = fig.layout.colorway
    if colorway is None:
        colorway = fig.layout.template.layout.colorway
    # Fall back to a larger palette when none is set or there are more
    # samples than available colors.
    if colorway is None or len(self.sample_labels) > len(colorway):
        colorway = px.colors.qualitative.Alphabet

    if self.n_splits > 0 and self.n_samples > 0:

        # One heatmap trace per sample mask
        for i, mask in enumerate(self.get_iter_sample_masks()):
            # Data: cells selected by the mask are set to the sample number
            df = mask.ffill()
            df[mask] = i
            # Transpose and reverse the row order before plotting
            data = df.transpose().iloc[::-1]

            # Settings
            color = colorway[i % len(colorway)]
            name = str(self.sample_labels[i])
            default_trace = asdict(Trace()) | dict(
                x=data.columns,
                y=data.index,
                legendgroup=name,
                name=name,
                # Single-color scale: every cell of this sample looks the same
                colorscale=[color, color],
                hovertemplate="%{x}<br>Split: %{y}<br>Sample: " + name,
            )

            # Plot
            fig = Heatmap(
                data=data,
                trace_kwargs=default_trace | trace_kwargs,
                add_trace_kwargs=add_trace_kwargs,
                is_y_category=True,
                fig=fig,
            ).figure

    return fig
686+
590687

591688
class BasePurgedCV:
592689
"""
@@ -781,13 +878,13 @@ def split(
781878
pred_times = pd.Series(pred_times, index=X.index)
782879
else:
783880
checks.assert_instance_of(
784-
pred_times,
785-
pd.Series,
881+
pred_times,
882+
pd.Series,
786883
arg_name="pred_times"
787884
)
788885
checks.assert_index_equal(
789-
X.index,
790-
pred_times.index,
886+
X.index,
887+
pred_times.index,
791888
check_names=False
792889
)
793890
if eval_times is None:

fold/model_selection/plots.py

Lines changed: 148 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,148 @@
1+
import abc
2+
from dataclasses import dataclass, field
3+
import typing as tp
4+
5+
import numpy as np
6+
7+
import plotly.graph_objects as go
8+
from plotly.basedatatypes import BaseFigure, BaseTraceType
9+
10+
11+
class BasePlotter(abc.ABC):
    """Base trace updating class.

    Holds a figure together with the traces on it that the plotter updates.
    Subclasses implement `update_trace` and `update`.
    """

    def __init__(self, figure: BaseFigure, traces: tp.Tuple[BaseTraceType, ...]):
        """Store the figure and the traces to update."""
        self._figure = figure
        self._traces = traces

    @property
    def figure(self) -> BaseFigure:
        """Figure."""
        return self._figure

    @property
    def traces(self) -> tp.Tuple[BaseTraceType, ...]:
        """Traces to update."""
        return self._traces

    @classmethod
    @abc.abstractmethod
    def update_trace(
        cls,
        trace: BaseTraceType,
        data: np.ndarray,
        *args,
        **kwargs
    ):
        """Update one trace."""
        pass

    @abc.abstractmethod
    def update(self, *args, **kwargs):
        """Update all traces using new data."""
        pass
43+
44+
45+
@dataclass
class Layout:
    """Default figure layout settings, convertible to a dict with `asdict`."""

    # Figure width
    width: int = 700
    # Figure height
    height: int = 350
    # Margins; built per instance so the dict is never shared
    margin: dict = field(
        default_factory=lambda: dict(t=30, b=30, l=30, r=30)
    )
    # Legend settings; built per instance so the dict is never shared
    legend: dict = field(
        default_factory=lambda: dict(
            x=1,
            y=1.02,
            orientation="h",
            yanchor="bottom",
            xanchor="right",
            traceorder="normal",
        )
    )
    # Scikit-learn colors
    colorway: list = field(
        default_factory=lambda: [
            "#007AB8",
            "#F7931E",
            "#505050",
        ]
    )
69+
70+
71+
@dataclass
class Trace:
    """Default heatmap trace settings, convertible to a dict with `asdict`."""

    # Hover setting for gaps in the data
    hoverongaps: bool = False
    # Whether to show the color scale
    showscale: bool = False
    # Whether the trace appears in the legend
    showlegend: bool = True
76+
77+
78+
class Heatmap(BasePlotter):
    """
    Create a heatmap plot.

    Parameters
    ----------
    data : array_like, optional
        Data in any format that can be converted to NumPy. Assigned to the
        trace's `z` when not None.
    is_x_category : bool, optional
        Whether X-axis is a categorical axis. Default is False.
    is_y_category : bool, optional
        Whether Y-axis is a categorical axis. Default is False.
    trace_kwargs : dict, optional
        Keyword arguments passed to `plotly.graph_objects.Heatmap`.
    add_trace_kwargs : dict, optional
        Keyword arguments passed to `add_trace`.
    figure_kwargs : dict, optional
        Keyword arguments used to build a new figure when `fig` is omitted.
    fig : Figure or FigureWidget, optional
        Figure to add traces to. A new `plotly.graph_objects.Figure` is
        created when omitted.
    **layout_kwargs : dict
        Keyword arguments for layout.
    """

    def __init__(
        self,
        data: tp.Optional[np.ndarray],
        is_x_category: tp.Optional[bool] = False,
        is_y_category: tp.Optional[bool] = False,
        trace_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        add_trace_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        figure_kwargs: tp.Optional[tp.Dict[str, tp.Any]] = None,
        fig: tp.Optional[BaseFigure] = None,
        **layout_kwargs
    ):
        # The None defaults must be normalised before being **-unpacked.
        trace_kwargs = trace_kwargs or {}
        add_trace_kwargs = add_trace_kwargs or {}
        figure_kwargs = figure_kwargs or {}

        # Build a figure when none is supplied instead of failing on None.
        if fig is None:
            fig = go.Figure(**figure_kwargs)

        trace = go.Heatmap(**trace_kwargs)

        if data is not None:
            self.update_trace(trace, data)

        fig.add_trace(trace, **add_trace_kwargs)
        # Read the trace back from the figure rather than reusing `trace`.
        added_trace = fig.data[-1]

        # Mark the axes this trace is bound to as categorical where requested.
        axis_kwargs = dict()
        for axis_name, is_category in (
            ("xaxis", is_x_category),
            ("yaxis", is_y_category),
        ):
            if not is_category:
                continue
            axis_ref = added_trace[axis_name]
            if axis_ref is not None:
                # Drop the reference's first character and append the rest
                # to the layout key.
                axis_kwargs[axis_name + axis_ref[1:]] = dict(type="category")
            else:
                axis_kwargs[axis_name] = dict(type="category")

        fig.update_layout(**axis_kwargs)
        fig.update_layout(**layout_kwargs)

        super().__init__(fig, (added_trace,))

    @classmethod
    def update_trace(
        cls,
        trace: BaseTraceType,
        data: np.ndarray,
        *args,
        **kwargs
    ):
        """Set `data` as the trace's `z` values."""
        trace.z = data

    def update(self, data: np.ndarray):
        """Update the heatmap trace with new data inside one batch update."""
        # The base class exposes the figure as `figure`; `self.fig` does not exist.
        with self.figure.batch_update():
            self.update_trace(self.traces[0], data)

0 commit comments

Comments
 (0)