Skip to content

Commit d257261

Browse files
Genuster and larsoner authored
BUG: Decoding compliance with sklearn [circle deploy] (#13393)
Co-authored-by: Eric Larson <[email protected]>
1 parent b00cf99 commit d257261

16 files changed

+259
-91
lines changed

doc/changes/dev/13393.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make decoding classes sklearn-compliant, by `Gennadiy Belonosov`_.

mne/decoding/_fixes.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,44 @@ def validate_data(
130130
out = X, y
131131

132132
return out
133+
134+
135+
def _check_n_features_3d(estimator, X, reset):
136+
"""Set the `n_features_in_` attribute, or check against it on an estimator.
137+
138+
Sklearn takes n_features from X.shape[1], but we need X.shape[-1]
139+
140+
Parameters
141+
----------
142+
estimator : estimator instance
143+
The estimator to validate the input for.
144+
145+
X : {ndarray, sparse matrix} of shape ([n_epochs], n_samples, n_features)
146+
The input samples.
147+
148+
reset : bool
149+
If True, the `n_features_in_` attribute is set to `X.shape[1]`.
150+
If False and the attribute exists, then check that it is equal to
151+
`X.shape[1]`. If False and the attribute does *not* exist, then
152+
the check is skipped.
153+
.. note::
154+
It is recommended to call reset=True in `fit` and in the first
155+
call to `partial_fit`. All other methods that validate `X`
156+
should set `reset=False`.
157+
"""
158+
n_features = X.shape[-1]
159+
if reset:
160+
estimator.n_features_in_ = n_features
161+
return
162+
163+
if not hasattr(estimator, "n_features_in_"):
164+
# Skip this check if the expected number of expected input features
165+
# was not recorded by calling fit first. This is typically the case
166+
# for stateless transformers.
167+
return
168+
169+
if n_features != estimator.n_features_in_:
170+
raise ValueError(
171+
f"X has {n_features} features, but {estimator.__class__.__name__} "
172+
f"is expecting {estimator.n_features_in_} features as input."
173+
)

mne/decoding/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ def __sklearn_tags__(self):
334334
tags.target_tags.one_d_labels = True
335335
tags.input_tags.two_d_array = True
336336
tags.input_tags.three_d_array = True
337+
tags.requires_fit = True
337338
return tags
338339

339340

mne/decoding/csp.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ def __init__(
160160
R_func=sum,
161161
)
162162

163+
def __sklearn_tags__(self):
164+
"""Tag the transformer.

Starts from the parent class's tags and sets ``target_tags.required``
and ``target_tags.multi_output`` to True before returning them.
"""
165+
tags = super().__sklearn_tags__()
166+
tags.target_tags.required = True
167+
tags.target_tags.multi_output = True
168+
return tags
169+
163170
def _validate_params(self, *, y):
164171
_validate_type(self.n_components, int, "n_components")
165172
if hasattr(self, "cov_est"):
@@ -187,7 +194,10 @@ def _validate_params(self, *, y):
187194
self.classes_ = np.unique(y)
188195
n_classes = len(self.classes_)
189196
if n_classes < 2:
190-
raise ValueError(f"n_classes must be >= 2, but got {n_classes} class")
197+
raise ValueError(
198+
"y should be a 1d array with more than two classes, "
199+
f"but got {n_classes} class from {y}"
200+
)
191201
elif n_classes > 2 and self.component_order == "alternate":
192202
raise ValueError(
193203
"component_order='alternate' requires two classes, but data contains "
@@ -756,6 +766,12 @@ def __init__(
756766
delattr(self, "cov_est")
757767
delattr(self, "norm_trace")
758768

769+
def __sklearn_tags__(self):
770+
"""Tag the transformer.

Starts from the parent class's tags and sets
``target_tags.multi_output`` to False before returning them.
"""
771+
tags = super().__sklearn_tags__()
772+
tags.target_tags.multi_output = False
773+
return tags
774+
759775
def fit(self, X, y):
760776
"""Estimate the SPoC decomposition on epochs.
761777

mne/decoding/receptive_field.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from sklearn.metrics import r2_score
1717

1818
from ..utils import _validate_type, fill_doc, pinv
19+
from ._fixes import _check_n_features_3d, validate_data
1920
from .base import _check_estimator, get_coef
2021
from .time_delaying_ridge import TimeDelayingRidge
2122

@@ -125,7 +126,7 @@ def __init__(
125126
self.tmax = tmax
126127
self.sfreq = sfreq
127128
self.feature_names = feature_names
128-
self.estimator = 0.0 if estimator is None else estimator
129+
self.estimator = estimator
129130
self.fit_intercept = fit_intercept
130131
self.scoring = scoring
131132
self.patterns = patterns
@@ -152,6 +153,19 @@ def __repr__(self): # noqa: D105
152153
s += f"scored ({self.scoring})"
153154
return f"<ReceptiveField | {s}>"
154155

156+
def __sklearn_tags__(self):
157+
"""Return the sklearn tags for this estimator.

Starts from the parent class's tags, marks the estimator type as
``"regressor"`` with a fresh ``RegressorTags()``, and sets the
``three_d_array`` input tag and the ``one_d_labels``, ``multi_output``
and ``required`` target tags to True.
"""
158+
# NOTE(review): imported inside the method rather than at module level --
# presumably because RegressorTags is missing in older sklearn; confirm.
from sklearn.utils import RegressorTags
159+
160+
tags = super().__sklearn_tags__()
161+
tags.estimator_type = "regressor"
162+
tags.regressor_tags = RegressorTags()
163+
tags.input_tags.three_d_array = True
164+
tags.target_tags.one_d_labels = True
165+
tags.target_tags.multi_output = True
166+
tags.target_tags.required = True
167+
return tags
168+
155169
def _delay_and_reshape(self, X, y=None):
156170
"""Delay and reshape the variables."""
157171
if not isinstance(self.estimator_, TimeDelayingRidge):
@@ -169,6 +183,32 @@ def _delay_and_reshape(self, X, y=None):
169183
y = y.reshape(-1, y.shape[-1], order="F")
170184
return X, y
171185

186+
def _check_data(self, X, y=None, reset=False):
"""Validate the input arrays and set or check ``n_features_in_``.

When ``reset`` is true, ``X`` and ``y`` are passed together to
``validate_data`` and validated separately, each with
``allow_nd=True, ensure_2d=False``. Otherwise only ``X`` is passed to
``validate_data`` and ``y`` is returned as it was given. The method also
calls ``_check_n_features_3d(self, X, reset)``.

Returns the tuple ``(X, y)``.
"""
187+
if reset:
188+
X, y = validate_data(
189+
self,
190+
X=X,
191+
y=y,
192+
reset=reset,
193+
validate_separately=( # to take care of 3D y
194+
dict(allow_nd=True, ensure_2d=False),
195+
dict(allow_nd=True, ensure_2d=False),
196+
),
197+
)
198+
else:
199+
X = validate_data(self, X=X, allow_nd=True, ensure_2d=False, reset=reset)
200+
_check_n_features_3d(self, X, reset)
201+
return X, y
202+
203+
def _validate_params(self, X):
"""Check ``scoring``, ``tmin`` and ``tmax``, and store ``sfreq_``.

Sets ``self.sfreq_`` to ``float(self.sfreq)``.

Raises ``ValueError`` if ``self.scoring`` is not a key of ``_SCORERS``
or if ``self.tmin`` is greater than ``self.tmax``.
"""
# NOTE(review): X is not referenced anywhere in this body -- confirm
# whether the parameter is still needed.
204+
if self.scoring not in _SCORERS.keys():
205+
raise ValueError(
206+
f"scoring must be one of {sorted(_SCORERS.keys())}, got {self.scoring}"
207+
)
208+
self.sfreq_ = float(self.sfreq)
209+
if self.tmin > self.tmax:
210+
raise ValueError(f"tmin ({self.tmin}) must be at most tmax ({self.tmax})")
211+
172212
def fit(self, X, y):
173213
"""Fit a receptive field model.
174214
@@ -184,22 +224,18 @@ def fit(self, X, y):
184224
self : instance
185225
The instance so you can chain operations.
186226
"""
187-
if self.scoring not in _SCORERS.keys():
188-
raise ValueError(
189-
f"scoring must be one of {sorted(_SCORERS.keys())}, got {self.scoring} "
190-
)
191-
self.sfreq_ = float(self.sfreq)
227+
X, y = self._check_data(X, y, reset=True)
228+
self._validate_params(X)
192229
X, y, _, self._y_dim = self._check_dimensions(X, y)
193230

194-
if self.tmin > self.tmax:
195-
raise ValueError(f"tmin ({self.tmin}) must be at most tmax ({self.tmax})")
196231
# Initialize delays
197232
self.delays_ = _times_to_delays(self.tmin, self.tmax, self.sfreq_)
198233

199234
# Define the slice that we should use in the middle
200235
self.valid_samples_ = _delays_to_slice(self.delays_)
201236

202-
if isinstance(self.estimator, numbers.Real):
237+
if self.estimator is None or isinstance(self.estimator, numbers.Real):
238+
alpha = self.estimator if self.estimator is not None else 0.0
203239
if self.fit_intercept is None:
204240
self.fit_intercept_ = True
205241
else:
@@ -208,7 +244,7 @@ def fit(self, X, y):
208244
self.tmin,
209245
self.tmax,
210246
self.sfreq_,
211-
alpha=self.estimator,
247+
alpha=alpha,
212248
fit_intercept=self.fit_intercept_,
213249
n_jobs=self.n_jobs,
214250
edge_correction=self.edge_correction,
@@ -259,6 +295,12 @@ def fit(self, X, y):
259295

260296
# Inverse-transform model weights
261297
if self.patterns:
298+
n_total_samples = n_times * n_epochs
299+
if n_total_samples < 2:
300+
raise ValueError(
301+
"Cannot compute patterns with only one sample; "
302+
f"got n_samples = {n_total_samples}."
303+
)
262304
if isinstance(self.estimator_, TimeDelayingRidge):
263305
cov_ = self.estimator_.cov_ / float(n_times * n_epochs - 1)
264306
y = y.reshape(-1, y.shape[-1], order="F")
@@ -300,7 +342,10 @@ def predict(self, X):
300342
"""
301343
if not hasattr(self, "delays_"):
302344
raise NotFittedError("Estimator has not been fit yet.")
345+
346+
X, _ = self._check_data(X)
303347
X, _, X_dim = self._check_dimensions(X, None, predict=True)[:3]
348+
304349
del _
305350
# convert to sklearn and back
306351
pred_shape = X.shape[:-1]
@@ -384,7 +429,10 @@ def _check_dimensions(self, X, y, predict=False):
384429
)
385430
else:
386431
raise ValueError(
387-
f"X must be shape (n_times[, n_epochs], n_features), got {X.shape}"
432+
"X must be shape (n_times[, n_epochs], n_features), "
433+
f"got {X.shape}. Reshape your data to 2D or 3D "
434+
"(e.g., array.reshape(-1, 1) for a single feature, "
435+
"or array.reshape(1, -1) for a single sample)."
388436
)
389437
if y is not None:
390438
if X.shape[0] != y.shape[0]:

mne/decoding/search_light.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,11 @@ def _transform(self, X, method):
190190
y_pred = np.concatenate(y_pred, axis=1)
191191
if orig_method == "transform":
192192
y_pred = y_pred.astype(X.dtype)
193-
if orig_method == "predict_proba" and not is_nd:
194-
y_pred = y_pred[:, 0, :]
193+
elif (
194+
orig_method in ("predict", "predict_proba", "decision_function")
195+
and not is_nd
196+
):
197+
y_pred = y_pred.squeeze()
195198
return y_pred
196199

197200
def transform(self, X):
@@ -525,8 +528,11 @@ def _transform(self, X, method):
525528
y_pred = np.concatenate(y_pred, axis=2)
526529
if orig_method == "transform":
527530
y_pred = y_pred.astype(X.dtype)
528-
if orig_method == "predict_proba" and not is_nd:
529-
y_pred = y_pred[:, 0, 0, :]
531+
if (
532+
orig_method in ("predict", "predict_proba", "decision_function")
533+
and not is_nd
534+
):
535+
y_pred = y_pred.squeeze()
530536
return y_pred
531537

532538
def transform(self, X):

mne/decoding/tests/test_ged.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,8 @@ def _mock_mod_ged_callable(evals, evecs, covs, **kwargs):
126126
cov_callable=[partial(_mock_cov_callable, cov_method_params=dict(reg="empirical"))],
127127
mod_ged_callable=[_mock_mod_ged_callable],
128128
dec_type=["single", "multi"],
129-
# XXX: Not covering "ssd" here because test_ssd.py works with 2D data.
130-
# Need to fix its tests first.
131-
restr_type=[
132-
"restricting",
133-
"whitening",
134-
],
135-
R_func=[partial(np.sum, axis=0)],
129+
restr_type=["restricting", "whitening"],
130+
R_func=[None, partial(np.sum, axis=0)],
136131
)
137132

138133
ged_estimators = [_GEDTransformer(**p) for p in ParameterGrid(param_grid)]

mne/decoding/tests/test_receptive_field.py

Lines changed: 10 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -589,22 +589,12 @@ def test_linalg_warning():
589589
@parametrize_with_checks([TimeDelayingRidge(0, 10, 1.0, 0.1, "laplacian", n_jobs=1)])
590590
def test_tdr_sklearn_compliance(estimator, check):
591591
"""Test sklearn estimator compliance."""
592-
# We don't actually comply with a bunch of the regressor specs :(
592+
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
593593
ignores = (
594-
"check_supervised_y_no_nan",
595-
"check_regressor",
596-
"check_parameters_default_constructible",
597-
"check_estimators_unfitted",
598-
"_invariance",
599-
"check_complex_data",
600-
"check_estimators_empty_data_messages",
601-
"check_estimators_nan_inf",
602-
"check_supervised_y_2d",
603-
"check_n_features_in",
604-
"check_fit2d_1sample",
605-
"check_fit1d",
606-
"check_fit2d_predict1d",
607-
"check_requires_y_none",
594+
# TDR convolves and thus its output cannot be invariant when
595+
# shuffled or subsampled.
596+
"check_methods_sample_order_invariance",
597+
"check_methods_subset_invariance",
608598
)
609599
if any(ignore in str(check) for ignore in ignores):
610600
return
@@ -615,17 +605,12 @@ def test_tdr_sklearn_compliance(estimator, check):
615605
@parametrize_with_checks([ReceptiveField(-1, 2, 1.0, estimator=Ridge(), patterns=True)])
616606
def test_rf_sklearn_compliance(estimator, check):
617607
"""Test sklearn RF compliance."""
608+
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
618609
ignores = (
619-
"check_parameters_default_constructible",
620-
"_invariance",
621-
"check_fit2d_1sample",
622-
# Should probably fix these?
623-
"check_complex_data",
624-
"check_dtype_object",
625-
"check_estimators_empty_data_messages",
626-
"check_n_features_in",
627-
"check_fit2d_predict1d",
628-
"check_estimators_unfitted",
610+
# RF does time-lagging, so its output cannot be invariant when
611+
# shuffled or subsampled.
612+
"check_methods_sample_order_invariance",
613+
"check_methods_subset_invariance",
629614
)
630615
if any(ignore in str(check) for ignore in ignores):
631616
return

mne/decoding/tests/test_search_light.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -354,13 +354,5 @@ def predict_proba(self, X):
354354
]
355355
)
356356
def test_sklearn_compliance(estimator, check):
357-
"""Test LinearModel compliance with sklearn."""
358-
ignores = (
359-
# TODO: we don't handle singleton right (probably)
360-
"check_classifiers_one_label_sample_weights",
361-
"check_classifiers_classes",
362-
"check_classifiers_train",
363-
)
364-
if any(ignore in str(check) for ignore in ignores):
365-
return
357+
"""Test searchlights compliance with sklearn."""
366358
check(estimator)

mne/decoding/tests/test_ssd.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -623,8 +623,14 @@ def test_sklearn_compliance(estimator, check):
623623
"""Test LinearModel compliance with sklearn."""
624624
pytest.importorskip("sklearn", minversion="1.4") # TODO VERSION remove on 1.4+
625625
ignores = (
626-
"check_methods_sample_order_invariance",
627-
# Shape stuff
626+
# Checks below fail because what sklearn passes as (n_samples, n_features)
627+
# is considered (n_channels, n_times) by SSD and creates problems
628+
# when n_channels change between fit and transform.
629+
# Could potentially be fixed by if X.ndim == 2: X = np.expand_dims(X, axis=2)
630+
# in fit and transform instead of axis=0.
631+
# But this will require to drop support for 2D inputs and expect
632+
# user to provide 3D array even if it's a continuous signal.
633+
"check_methods_sample_order_invariance", # SSD is not time-invariant
628634
"check_fit_idempotent",
629635
"check_methods_subset_invariance",
630636
"check_transformer_general",

0 commit comments

Comments
 (0)