Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3ad8c28
initial improvements
Genuster Jul 30, 2025
a3479cf
clone(model) instead of modifying it, change default model, add depre…
Genuster Jul 30, 2025
09dc04f
Merge remote-tracking branch 'upstream/main' into linearmodel
Genuster Aug 6, 2025
9346f72
make linearmodel work with onevsrestclassifier
Genuster Aug 6, 2025
4e2bcbd
few last fixes
Genuster Aug 6, 2025
568f0ea
add futurewarning test
Genuster Aug 6, 2025
f138d13
add short docstring for predict method
Genuster Aug 6, 2025
838153b
add returns to the predict docstring
Genuster Aug 6, 2025
b7d4c0f
use model.__sklearn__tags__ instead of get_tags
Genuster Aug 6, 2025
9f2354f
use mne's fix for validate_data
Genuster Aug 6, 2025
74177d2
move predict and classes_ back to the wrapped attrs
Genuster Aug 6, 2025
222ed16
add fit_transform test
Genuster Aug 6, 2025
d7ca9ff
__getattr__ can silently catch attribute error from filters_ property
Genuster Aug 6, 2025
175f6c6
make validation compatible with sklearn<1.6
Genuster Aug 7, 2025
40dec84
more old sklearn fixes
Genuster Aug 7, 2025
4cfd798
Merge remote-tracking branch 'upstream/main' into linearmodel
Genuster Aug 7, 2025
973c898
undo fit check from filters property
Genuster Aug 8, 2025
5e34bd0
add changelog entry
Genuster Aug 8, 2025
452d98b
fix typo in changelog
Genuster Aug 8, 2025
f6b8079
Merge branch 'main' into linearmodel
Genuster Aug 8, 2025
73357a1
undo api changes
Genuster Aug 11, 2025
dd68588
another tiny undo
Genuster Aug 11, 2025
8160516
Merge branch 'main' into linearmodel
Genuster Aug 12, 2025
9397f4e
Merge remote-tracking branch 'upstream/main' into linearmodel
Genuster Aug 18, 2025
a088332
TST: All examples [circle full]
Genuster Aug 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/changes/dev/13361.apichange.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Starting with MNE-Python v1.13, ``model`` parameter of :class:`mne.decoding.LinearModel`
will not be modified, use ``model_`` attribute to access the fitted model.
In addition, the default ``None`` for ``model`` will not automatically set
:class:`sklearn.linear_model.LogisticRegression`, it will need to be provided explicitly.
The ``model`` is expected to be a supervised predictor, i.e. classifier or regressor
(or :class:`sklearn.multiclass.OneVsRestClassifier`), otherwise an error will be raised,
by `Gennadiy Belonosov`_.
135 changes: 98 additions & 37 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,25 @@
TransformerMixin,
clone,
is_classifier,
is_regressor,
)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import check_scoring
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.utils import check_array, check_X_y, indexable
from sklearn.utils import indexable
from sklearn.utils.validation import check_is_fitted

from ..parallel import parallel_func
from ..utils import _check_option, _pl, _validate_type, logger, pinv, verbose, warn
from ..utils import (
_check_option,
_pl,
_validate_type,
logger,
pinv,
verbose,
warn,
)
from ._fixes import validate_data
from ._ged import (
_handle_restr_mat,
_is_cov_pos_semidef,
Expand Down Expand Up @@ -340,7 +350,8 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
model : object | None
A linear model from scikit-learn with a fit method
that updates a ``coef_`` attribute.
If None the model will be LogisticRegression.
If None the model will be
:class:`sklearn.linear_model.LogisticRegression`.

Attributes
----------
Expand All @@ -364,46 +375,77 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
.. footbibliography::
"""

# TODO: Properly refactor this using
# https://github.com/scikit-learn/scikit-learn/issues/30237#issuecomment-2465572885
_model_attr_wrap = (
"transform",
"fit_transform",
"predict",
"predict_proba",
"_estimator_type",
"__tags__",
"predict_log_proba",
"_estimator_type", # remove after sklearn 1.6
"decision_function",
"score",
"classes_",
)

def __init__(self, model=None):
# TODO: We need to set this to get our tag checking to work properly
# XXX Remove the clause after warning cycle
if model is None:
model = LogisticRegression(solver="liblinear")
depr_message = (
"Starting with mne-python v1.13 'model' default "
"will change from LogisticRegression to None. "
"From now on please set model=LogisticRegression"
"(solver='liblinear') explicitly."
)
warn(depr_message, FutureWarning)

self.model = model

def __sklearn_tags__(self):
"""Get sklearn tags."""
from sklearn.utils import get_tags # added in 1.6

# fit method below does not allow sparse data via check_data, we could
# eventually make it smarter if we had to
tags = get_tags(self.model)
tags.input_tags.sparse = False
tags = super().__sklearn_tags__()
# XXX Change self._orig_model to self.model after 'model' warning cycle
model_tags = self._orig_model.__sklearn_tags__()
tags.estimator_type = model_tags.estimator_type
if tags.estimator_type is not None:
model_type_tags = getattr(model_tags, f"{tags.estimator_type}_tags")
setattr(tags, f"{tags.estimator_type}_tags", model_type_tags)
return tags

def __getattr__(self, attr):
"""Wrap to model for some attributes."""
if attr in LinearModel._model_attr_wrap:
return getattr(self.model, attr)
elif attr == "fit_transform" and hasattr(self.model, "fit_transform"):
return super().__getattr__(self, "_fit_transform")
return super().__getattr__(self, attr)
# XXX Change self._orig_model to self.model after 'model' warning cycle
model = self.model_ if "model_" in self.__dict__ else self._orig_model
if attr == "fit_transform" and hasattr(model, "fit_transform"):
return self._fit_transform
else:
return getattr(model, attr)
else:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{attr}'"
)

def _fit_transform(self, X, y):
return self.fit(X, y).transform(X)

def _validate_params(self, X):
model = self._orig_model
if isinstance(model, MetaEstimatorMixin):
model = model.estimator
is_predictor = is_regressor(model) or is_classifier(model)
if not is_predictor:
raise ValueError(
"Linear model should be a supervised predictor "
"(classifier or regressor)"
)

# For sklearn < 1.6
try:
self._check_n_features(X, reset=True)
except AttributeError:
pass

def fit(self, X, y, **fit_params):
"""Estimate the coefficients of the linear model.

Expand All @@ -424,25 +466,15 @@ def fit(self, X, y, **fit_params):
self : instance of LinearModel
Returns the modified instance.
"""
if y is not None:
X = check_array(X)
else:
X, y = check_X_y(X, y)
self.n_features_in_ = X.shape[1]
if y is not None:
y = check_array(y, dtype=None, ensure_2d=False, input_name="y")
if y.ndim > 2:
raise ValueError(
f"LinearModel only accepts up to 2-dimensional y, got {y.shape} "
"instead."
)
self._validate_params(X)
X, y = validate_data(self, X, y, multi_output=True)

# fit the Model
self.model.fit(X, y, **fit_params)
self.model_ = self.model # for better sklearn compat
# XXX Change self._orig_model to self.model after 'model' warning cycle
self.model_ = clone(self._orig_model)
self.model_.fit(X, y, **fit_params)

# Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y

inv_Y = 1.0
X = X - X.mean(0, keepdims=True)
if y.ndim == 2 and y.shape[1] != 1:
Expand All @@ -454,18 +486,47 @@ def fit(self, X, y, **fit_params):

@property
def filters_(self):
if hasattr(self.model, "coef_"):
if hasattr(self.model_, "coef_"):
# Standard Linear Model
filters = self.model.coef_
elif hasattr(self.model.best_estimator_, "coef_"):
filters = self.model_.coef_
elif hasattr(self.model_, "estimators_"):
# Linear model with OneVsRestClassifier
filters = np.vstack([est.coef_ for est in self.model_.estimators_])
elif hasattr(self.model_, "best_estimator_") and hasattr(
self.model_.best_estimator_, "coef_"
):
# Linear Model with GridSearchCV
filters = self.model.best_estimator_.coef_
filters = self.model_.best_estimator_.coef_
else:
raise ValueError("model does not have a `coef_` attribute.")
if filters.ndim == 2 and filters.shape[0] == 1:
filters = filters[0]
return filters

# XXX Remove this property after 'model' warning cycle
@property
def model(self):
if "model_" in self.__dict__:
depr_message = (
"Starting with mne-python v1.13 'model' attribute "
"of LinearModel will not be fitted, "
"please use 'model_' instead"
)
warn(depr_message, FutureWarning)
return self.model_
else:
return self._orig_model

# XXX Remove this after 'model' warning cycle
@model.setter
def model(self, value):
self._orig_model = value

# XXX Remove this after 'model' warning cycle
def __repr__(self):
"""Avoid FutureWarning from filter_ when printing the instance."""
return f"LinearModel(model={self._orig_model})"


def _set_cv(cv, estimator=None, X=None, y=None):
"""Set the default CV depending on whether clf is classifier/regressor."""
Expand Down
53 changes: 34 additions & 19 deletions mne/decoding/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
is_classifier,
is_regressor,
)
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
from sklearn.model_selection import (
GridSearchCV,
KFold,
StratifiedKFold,
cross_val_score,
)
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.estimator_checks import parametrize_with_checks
Expand Down Expand Up @@ -93,12 +95,11 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3):
return X, Y, A


@pytest.mark.filterwarnings("ignore:invalid value encountered in cast.*:RuntimeWarning")
def test_get_coef():
"""Test getting linear coefficients (filters/patterns) from estimators."""
lm_classification = LinearModel()
lm_classification = LinearModel(LogisticRegression(solver="liblinear"))
assert hasattr(lm_classification, "__sklearn_tags__")
if check_version("sklearn", "1.4"):
if check_version("sklearn", "1.6"):
print(lm_classification.__sklearn_tags__())
assert is_classifier(lm_classification.model)
assert is_classifier(lm_classification)
Expand Down Expand Up @@ -200,19 +201,19 @@ def inverse_transform(self, X):
# Retrieve final linear model
filters = get_coef(clf, "filters_", False)
if hasattr(clf, "steps"):
if hasattr(clf.steps[-1][-1].model, "best_estimator_"):
if hasattr(clf.steps[-1][-1].model_, "best_estimator_"):
# Linear Model with GridSearchCV
coefs = clf.steps[-1][-1].model.best_estimator_.coef_
coefs = clf.steps[-1][-1].model_.best_estimator_.coef_
else:
# Standard Linear Model
coefs = clf.steps[-1][-1].model.coef_
coefs = clf.steps[-1][-1].model_.coef_
else:
if hasattr(clf.model, "best_estimator_"):
if hasattr(clf.model_, "best_estimator_"):
# Linear Model with GridSearchCV
coefs = clf.model.best_estimator_.coef_
coefs = clf.model_.best_estimator_.coef_
else:
# Standard Linear Model
coefs = clf.model.coef_
coefs = clf.model_.coef_
if coefs.ndim == 2 and coefs.shape[0] == 1:
coefs = coefs[0]
assert_array_equal(filters, coefs)
Expand Down Expand Up @@ -277,12 +278,13 @@ def test_get_coef_multiclass(n_features, n_targets):
"""Test get_coef on multiclass problems."""
# Check patterns with more than 1 regressor
X, Y, A = _make_data(n_samples=30000, n_features=n_features, n_targets=n_targets)
with pytest.warns(FutureWarning, match="'model' default"):
_ = LinearModel()
lm = LinearModel(LinearRegression())
assert not hasattr(lm, "model_")
lm.fit(X, Y)
# TODO: modifying non-underscored `model` is a sklearn no-no, maybe should be a
# metaestimator?
assert lm.model is lm.model_
with pytest.warns(FutureWarning, match="'model' attribute of LinearModel"):
assert lm.model is lm.model_
assert_array_equal(lm.filters_.shape, lm.patterns_.shape)
if n_targets == 1:
want_shape = (n_features,)
Expand Down Expand Up @@ -328,9 +330,6 @@ def test_get_coef_multiclass(n_features, n_targets):
(3, 1, 2),
],
)
# TODO: Need to fix this properly in LinearModel
@pytest.mark.filterwarnings("ignore:'multi_class' was depr.*:FutureWarning")
@pytest.mark.filterwarnings("ignore:lbfgs failed to converge.*:")
def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
"""Test a full example with pattern extraction."""
data = np.zeros((10 * n_classes, n_channels, n_times))
Expand All @@ -345,7 +344,7 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
clf = make_pipeline(
Scaler(epochs.info),
Vectorizer(),
LinearModel(LogisticRegression(random_state=0, multi_class="ovr")),
LinearModel(OneVsRestClassifier(LogisticRegression(random_state=0))),
)
scorer = "roc_auc_ovr_weighted"
time_gen = GeneralizingEstimator(clf, scorer, verbose=True)
Expand All @@ -371,7 +370,7 @@ def test_linearmodel():
"""Test LinearModel class for computing filters and patterns."""
# check categorical target fit in standard linear model
rng = np.random.RandomState(0)
clf = LinearModel()
clf = LinearModel(LogisticRegression(solver="liblinear"))
n, n_features = 20, 3
X = rng.rand(n, n_features)
y = np.arange(n) % 2
Expand All @@ -382,6 +381,20 @@ def test_linearmodel():
wrong_X = rng.rand(n, n_features, 99)
clf.fit(wrong_X, y)

# check fit_transform call
clf = LinearModel(LinearDiscriminantAnalysis())
_ = clf.fit_transform(X, y)

# check that model has to have coef_, RBF-SVM doesn't
clf = LinearModel(svm.SVC(kernel="rbf"))
with pytest.raises(ValueError, match="does not have a `coef_`"):
clf.fit(X, y)

# check that model has to be a predictor
clf = LinearModel(StandardScaler())
with pytest.raises(ValueError, match="classifier or regressor"):
clf.fit(X, y)

# check categorical target fit in standard linear model with GridSearchCV
parameters = {"kernel": ["linear"], "C": [1, 10]}
clf = LinearModel(
Expand Down Expand Up @@ -478,12 +491,14 @@ def test_cross_val_multiscore():
assert_array_equal(manual, auto)


# XXX Remove the filterwarning after 'model' warning cycle
@pytest.mark.filterwarnings("ignore::FutureWarning")
@parametrize_with_checks([LinearModel(LogisticRegression())])
def test_sklearn_compliance(estimator, check):
"""Test LinearModel compliance with sklearn."""
# XXX Remove the ignores after 'model' warning cycle
ignores = (
"check_estimators_overwrite_params", # self.model changes!
"check_dont_overwrite_parameters",
"check_estimators_overwrite_params",
"check_parameters_default_constructible",
)
if any(ignore in str(check) for ignore in ignores):
Expand Down
7 changes: 6 additions & 1 deletion mne/decoding/tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from sklearn.decomposition import PCA
from sklearn.kernel_ridge import KernelRidge
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.estimator_checks import parametrize_with_checks
Expand Down Expand Up @@ -229,7 +230,11 @@ def test_vectorizer():
# And that pipelines work properly
X_arr = EpochsArray(X, create_info(12, 1000.0, "eeg"))
vect.fit(X_arr)
clf = make_pipeline(Vectorizer(), StandardScaler(), LinearModel())
clf = make_pipeline(
Vectorizer(),
StandardScaler(),
LinearModel(LogisticRegression(solver="liblinear")),
)
clf.fit(X_arr, y)


Expand Down
Loading