Skip to content

Commit a83f569

Browse files
committed
plot grouptimeseries split
1 parent dbe4502 commit a83f569

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

examples/model_selection/plot_cv_indices.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sklearn.model_selection import (
2323
GroupKFold,
2424
GroupShuffleSplit,
25+
GroupTimeSeriesSplit,
2526
KFold,
2627
ShuffleSplit,
2728
StratifiedGroupKFold,
@@ -60,6 +61,8 @@
6061
group_prior = rng.dirichlet([2] * 10)
6162
groups = np.repeat(np.arange(10), rng.multinomial(100, group_prior))
6263

64+
# Un-Evenly spaced groups repeated once
65+
unevengroups = np.hstack([[group] * 10 if group % 3 else [group] * 5 for group in range(12)])
6366

6467
def visualize_groups(classes, groups, name):
6568
# Visualize dataset groups
@@ -197,6 +200,7 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
197200
StratifiedKFold,
198201
StratifiedGroupKFold,
199202
GroupShuffleSplit,
203+
GroupTimeSeriesSplit,
200204
StratifiedShuffleSplit,
201205
TimeSeriesSplit,
202206
]
@@ -205,6 +209,8 @@ def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
205209
for cv in cvs:
206210
this_cv = cv(n_splits=n_splits)
207211
fig, ax = plt.subplots(figsize=(6, 3))
212+
if cv == GroupTimeSeriesSplit:
213+
groups = unevengroups
208214
plot_cv_indices(this_cv, X, y, groups, ax, n_splits)
209215

210216
ax.legend(

0 commit comments

Comments
 (0)