Skip to content

Commit 7686284

Browse files
committed
Clean Commit
1 parent 1fe6595 commit 7686284

File tree

3 files changed

+1150
-0
lines changed

3 files changed

+1150
-0
lines changed

sklearn/model_selection/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
BaseShuffleSplit,
2222
GroupKFold,
2323
GroupShuffleSplit,
24+
GroupTimeSeriesSplit,
2425
KFold,
2526
LeaveOneGroupOut,
2627
LeaveOneOut,

sklearn/model_selection/_split.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"BaseCrossValidator",
4040
"GroupKFold",
4141
"GroupShuffleSplit",
42+
"GroupTimeSeriesSplit",
4243
"KFold",
4344
"LeaveOneGroupOut",
4445
"LeaveOneOut",
@@ -2430,6 +2431,146 @@ def split(self, X, y, groups=None):
24302431
return super().split(X, y, groups)
24312432

24322433

2434+
class GroupTimeSeriesSplit(GroupsConsumerMixin, _BaseKFold):
    """Time Series cross-validator variant with non-overlapping groups.

    Splits time-series samples into train/test folds along the group
    boundaries supplied by the caller. Test indices are always later than
    train indices, so shuffling is never performed, and each group appears
    in at most one fold (the number of distinct groups must therefore be
    at least ``n_splits + 1``). Unlike standard cross-validation methods,
    successive training sets are supersets of those that come before them.

    The group labels must be contiguous, e.g.::

        valid_groups = np.array([
            'a', 'a', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'd', 'd', 'd'
        ])

    A label that reappears after a different label was seen, e.g.::

        invalid_groups = np.array([
            'a', 'a', 'b', 'b', 'b', 'b', 'b', 'a', 'c', 'c', 'c', 'b', 'd', 'd'
        ])

    raises a ``ValueError``.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_splits : int, default=5
        Number of splits. Must be at least 2.

    max_train_size : int, default=None
        Maximum number of samples kept in a single training fold.

    test_size : int, default=None
        Maximum number of samples kept in a single test fold.

    gap : int, default=0
        Number of groups to exclude from the end of each training fold
        before the test fold.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.model_selection import GroupTimeSeriesSplit
    >>> groups = np.array([
    ...     'a', 'a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'd', 'd', 'd']
    ... )
    >>> gtss = GroupTimeSeriesSplit(n_splits=3)
    >>> for train_idx, test_idx in gtss.split(groups, groups=groups):
    ...     print("TRAIN:", train_idx, "TEST:", test_idx)
    TRAIN: [0 1 2 3] TEST: [4 5 6]
    TRAIN: [0 1 2 3 4 5 6] TEST: [ 7  8  9 10]
    TRAIN: [ 0  1  2  3  4  5  6  7  8  9 10] TEST: [11 12 13]
    """

    def __init__(self, n_splits=5, *, max_train_size=None, test_size=None, gap=0):
        # Shuffling is meaningless for ordered time-series folds, so it is
        # hard-wired off.
        super().__init__(n_splits, shuffle=False, random_state=None)
        self.max_train_size = max_train_size
        self.test_size = test_size
        self.gap = gap

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training and test set.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.

        y : array-like of shape (n_samples,)
            Always ignored, exists for compatibility.

        groups : array-like of shape (n_samples,)
            Group labels for the samples used while splitting the dataset into
            train/test set.

        Yields
        ------
        train : ndarray
            The training set indices for that split.

        test : ndarray
            The testing set indices for that split.
        """
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")
        X, y, groups = indexable(X, y, groups)

        # Recover the unique groups in order of first appearance:
        # `np.unique` sorts its output, so undo the sort via the
        # first-occurrence positions it returns.
        sorted_uniques, first_positions = np.unique(groups, return_index=True)
        unique_groups = sorted_uniques[np.argsort(first_positions)]

        n_folds = self.n_splits + 1
        n_groups = len(unique_groups)
        if n_folds > n_groups:
            raise ValueError(
                f"Cannot have number of folds={n_folds} "
                f"greater than the number of groups={n_groups}"
            )

        # Reject any label that reappears after a different label was seen;
        # non-contiguous groups would make the group-level split meaningless.
        observed = set()
        previous = None
        for position, label in enumerate(groups):
            if label != previous and label in observed:
                raise ValueError(
                    "The groups should be contiguous."
                    " Found a non-contiguous group at"
                    f" index={position}"
                )
            observed.add(label)
            previous = label

        # Split at group granularity, then expand each fold back to the
        # sample indices belonging to its groups.
        group_splitter = TimeSeriesSplit(
            n_splits=self.n_splits,
            max_train_size=None,
            test_size=None,
            gap=self.gap,
        )
        for train_groups, test_groups in group_splitter.split(unique_groups):
            train_array = np.where(np.isin(groups, unique_groups[train_groups]))[0]
            test_array = np.where(np.isin(groups, unique_groups[test_groups]))[0]
            if self.max_train_size and len(train_array) > self.max_train_size:
                # Keep only the most recent samples in the training fold.
                train_array = train_array[-self.max_train_size :]
            if self.test_size:
                test_array = test_array[: self.test_size]
            yield train_array, test_array
24332574
def _validate_shuffle_split(n_samples, test_size, train_size, default_test_size=None):
24342575
"""
24352576
Validation helper to check if the train/test sizes are meaningful w.r.t. the

0 commit comments

Comments
 (0)