Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3ad8c28
initial improvements
Genuster Jul 30, 2025
a3479cf
clone(model) instead of modifying it, change default model, add depre…
Genuster Jul 30, 2025
09dc04f
Merge remote-tracking branch 'upstream/main' into linearmodel
Genuster Aug 6, 2025
9346f72
make linearmodel work with onevsrestclassifier
Genuster Aug 6, 2025
4e2bcbd
few last fixes
Genuster Aug 6, 2025
568f0ea
add futurewarning test
Genuster Aug 6, 2025
f138d13
add short docstring for predict method
Genuster Aug 6, 2025
838153b
add returns to the predict docstring
Genuster Aug 6, 2025
b7d4c0f
use model.__sklearn__tags__ instead of get_tags
Genuster Aug 6, 2025
9f2354f
use mne's fix for validate_data
Genuster Aug 6, 2025
74177d2
move predict and classes_ back to the wrapped attrs
Genuster Aug 6, 2025
222ed16
add fit_transform test
Genuster Aug 6, 2025
d7ca9ff
__getattr__ can silently catch attribute error from filters_ property
Genuster Aug 6, 2025
175f6c6
make validation compatible with sklearn<1.6
Genuster Aug 7, 2025
40dec84
more old sklearn fixes
Genuster Aug 7, 2025
4cfd798
Merge remote-tracking branch 'upstream/main' into linearmodel
Genuster Aug 7, 2025
973c898
undo fit check from filters property
Genuster Aug 8, 2025
5e34bd0
add changelog entry
Genuster Aug 8, 2025
452d98b
fix typo in changelog
Genuster Aug 8, 2025
f6b8079
Merge branch 'main' into linearmodel
Genuster Aug 8, 2025
73357a1
undo api changes
Genuster Aug 11, 2025
dd68588
another tiny undo
Genuster Aug 11, 2025
8160516
Merge branch 'main' into linearmodel
Genuster Aug 12, 2025
9397f4e
Merge remote-tracking branch 'upstream/main' into linearmodel
Genuster Aug 18, 2025
a088332
TST: All examples [circle full]
Genuster Aug 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/changes/dev/13361.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
``model`` parameter of :class:`mne.decoding.LinearModel`
will not be modified, use ``model_`` attribute to access the fitted model.
To be compatible with all MNE-Python versions you can use
``getattr(clf, "model_", getattr(clf, "model"))``
The provided ``model`` is expected to be a supervised predictor,
i.e. classifier or regressor (or :class:`sklearn.multiclass.OneVsRestClassifier`),
otherwise an error will be raised.
by `Gennadiy Belonosov`_.
107 changes: 68 additions & 39 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,25 @@
TransformerMixin,
clone,
is_classifier,
is_regressor,
)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import check_scoring
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.utils import check_array, check_X_y, indexable
from sklearn.utils import indexable
from sklearn.utils.validation import check_is_fitted

from ..parallel import parallel_func
from ..utils import _check_option, _pl, _validate_type, logger, pinv, verbose, warn
from ..utils import (
_check_option,
_pl,
_validate_type,
logger,
pinv,
verbose,
warn,
)
from ._fixes import validate_data
from ._ged import (
_handle_restr_mat,
_is_cov_pos_semidef,
Expand Down Expand Up @@ -340,7 +350,8 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
model : object | None
A linear model from scikit-learn with a fit method
that updates a ``coef_`` attribute.
If None the model will be LogisticRegression.
If None the model will be
:class:`sklearn.linear_model.LogisticRegression`.

Attributes
----------
Expand All @@ -364,46 +375,66 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
.. footbibliography::
"""

# TODO: Properly refactor this using
# https://github.com/scikit-learn/scikit-learn/issues/30237#issuecomment-2465572885
_model_attr_wrap = (
"transform",
"fit_transform",
"predict",
"predict_proba",
"_estimator_type",
"__tags__",
"predict_log_proba",
"_estimator_type", # remove after sklearn 1.6
"decision_function",
"score",
"classes_",
)

def __init__(self, model=None):
# TODO: We need to set this to get our tag checking to work properly
if model is None:
model = LogisticRegression(solver="liblinear")
self.model = model

def __sklearn_tags__(self):
"""Get sklearn tags."""
from sklearn.utils import get_tags # added in 1.6

# fit method below does not allow sparse data via check_data, we could
# eventually make it smarter if we had to
tags = get_tags(self.model)
tags.input_tags.sparse = False
tags = super().__sklearn_tags__()
model = self.model if self.model is not None else LogisticRegression()
model_tags = model.__sklearn_tags__()
tags.estimator_type = model_tags.estimator_type
if tags.estimator_type is not None:
model_type_tags = getattr(model_tags, f"{tags.estimator_type}_tags")
setattr(tags, f"{tags.estimator_type}_tags", model_type_tags)
return tags

def __getattr__(self, attr):
"""Wrap to model for some attributes."""
if attr in LinearModel._model_attr_wrap:
return getattr(self.model, attr)
elif attr == "fit_transform" and hasattr(self.model, "fit_transform"):
return super().__getattr__(self, "_fit_transform")
return super().__getattr__(self, attr)
model = self.model_ if "model_" in self.__dict__ else self.model
if attr == "fit_transform" and hasattr(model, "fit_transform"):
return self._fit_transform
else:
return getattr(model, attr)
else:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{attr}'"
)

def _fit_transform(self, X, y):
return self.fit(X, y).transform(X)

def _validate_params(self, X):
if self.model is not None:
model = self.model
if isinstance(model, MetaEstimatorMixin):
model = model.estimator
is_predictor = is_regressor(model) or is_classifier(model)
if not is_predictor:
raise ValueError(
"Linear model should be a supervised predictor "
"(classifier or regressor)"
)

# For sklearn < 1.6
try:
self._check_n_features(X, reset=True)
except AttributeError:
pass

def fit(self, X, y, **fit_params):
"""Estimate the coefficients of the linear model.

Expand All @@ -424,25 +455,18 @@ def fit(self, X, y, **fit_params):
self : instance of LinearModel
Returns the modified instance.
"""
if y is not None:
X = check_array(X)
else:
X, y = check_X_y(X, y)
self.n_features_in_ = X.shape[1]
if y is not None:
y = check_array(y, dtype=None, ensure_2d=False, input_name="y")
if y.ndim > 2:
raise ValueError(
f"LinearModel only accepts up to 2-dimensional y, got {y.shape} "
"instead."
)
self._validate_params(X)
X, y = validate_data(self, X, y, multi_output=True)

# fit the Model
self.model.fit(X, y, **fit_params)
self.model_ = self.model # for better sklearn compat
self.model_ = (
clone(self.model)
if self.model is not None
else LogisticRegression(solver="liblinear")
)
self.model_.fit(X, y, **fit_params)

# Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y

inv_Y = 1.0
X = X - X.mean(0, keepdims=True)
if y.ndim == 2 and y.shape[1] != 1:
Expand All @@ -454,12 +478,17 @@ def fit(self, X, y, **fit_params):

@property
def filters_(self):
if hasattr(self.model, "coef_"):
if hasattr(self.model_, "coef_"):
# Standard Linear Model
filters = self.model.coef_
elif hasattr(self.model.best_estimator_, "coef_"):
filters = self.model_.coef_
elif hasattr(self.model_, "estimators_"):
# Linear model with OneVsRestClassifier
filters = np.vstack([est.coef_ for est in self.model_.estimators_])
elif hasattr(self.model_, "best_estimator_") and hasattr(
self.model_.best_estimator_, "coef_"
):
# Linear Model with GridSearchCV
filters = self.model.best_estimator_.coef_
filters = self.model_.best_estimator_.coef_
else:
raise ValueError("model does not have a `coef_` attribute.")
if filters.ndim == 2 and filters.shape[0] == 1:
Expand Down
49 changes: 26 additions & 23 deletions mne/decoding/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
is_classifier,
is_regressor,
)
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
from sklearn.model_selection import (
GridSearchCV,
KFold,
StratifiedKFold,
cross_val_score,
)
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.estimator_checks import parametrize_with_checks
Expand Down Expand Up @@ -93,12 +95,11 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3):
return X, Y, A


@pytest.mark.filterwarnings("ignore:invalid value encountered in cast.*:RuntimeWarning")
def test_get_coef():
"""Test getting linear coefficients (filters/patterns) from estimators."""
lm_classification = LinearModel()
lm_classification = LinearModel(LogisticRegression(solver="liblinear"))
assert hasattr(lm_classification, "__sklearn_tags__")
if check_version("sklearn", "1.4"):
if check_version("sklearn", "1.6"):
print(lm_classification.__sklearn_tags__())
assert is_classifier(lm_classification.model)
assert is_classifier(lm_classification)
Expand Down Expand Up @@ -200,19 +201,19 @@ def inverse_transform(self, X):
# Retrieve final linear model
filters = get_coef(clf, "filters_", False)
if hasattr(clf, "steps"):
if hasattr(clf.steps[-1][-1].model, "best_estimator_"):
if hasattr(clf.steps[-1][-1].model_, "best_estimator_"):
# Linear Model with GridSearchCV
coefs = clf.steps[-1][-1].model.best_estimator_.coef_
coefs = clf.steps[-1][-1].model_.best_estimator_.coef_
else:
# Standard Linear Model
coefs = clf.steps[-1][-1].model.coef_
coefs = clf.steps[-1][-1].model_.coef_
else:
if hasattr(clf.model, "best_estimator_"):
if hasattr(clf.model_, "best_estimator_"):
# Linear Model with GridSearchCV
coefs = clf.model.best_estimator_.coef_
coefs = clf.model_.best_estimator_.coef_
else:
# Standard Linear Model
coefs = clf.model.coef_
coefs = clf.model_.coef_
if coefs.ndim == 2 and coefs.shape[0] == 1:
coefs = coefs[0]
assert_array_equal(filters, coefs)
Expand Down Expand Up @@ -280,9 +281,7 @@ def test_get_coef_multiclass(n_features, n_targets):
lm = LinearModel(LinearRegression())
assert not hasattr(lm, "model_")
lm.fit(X, Y)
# TODO: modifying non-underscored `model` is a sklearn no-no, maybe should be a
# metaestimator?
assert lm.model is lm.model_
assert lm.model is not lm.model_
assert_array_equal(lm.filters_.shape, lm.patterns_.shape)
if n_targets == 1:
want_shape = (n_features,)
Expand Down Expand Up @@ -328,9 +327,6 @@ def test_get_coef_multiclass(n_features, n_targets):
(3, 1, 2),
],
)
# TODO: Need to fix this properly in LinearModel
@pytest.mark.filterwarnings("ignore:'multi_class' was depr.*:FutureWarning")
@pytest.mark.filterwarnings("ignore:lbfgs failed to converge.*:")
def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
"""Test a full example with pattern extraction."""
data = np.zeros((10 * n_classes, n_channels, n_times))
Expand All @@ -345,7 +341,7 @@ def test_get_coef_multiclass_full(n_classes, n_channels, n_times):
clf = make_pipeline(
Scaler(epochs.info),
Vectorizer(),
LinearModel(LogisticRegression(random_state=0, multi_class="ovr")),
LinearModel(OneVsRestClassifier(LogisticRegression(random_state=0))),
)
scorer = "roc_auc_ovr_weighted"
time_gen = GeneralizingEstimator(clf, scorer, verbose=True)
Expand Down Expand Up @@ -382,6 +378,20 @@ def test_linearmodel():
wrong_X = rng.rand(n, n_features, 99)
clf.fit(wrong_X, y)

# check fit_transform call
clf = LinearModel(LinearDiscriminantAnalysis())
_ = clf.fit_transform(X, y)

# check that model has to have coef_, RBF-SVM doesn't
clf = LinearModel(svm.SVC(kernel="rbf"))
with pytest.raises(ValueError, match="does not have a `coef_`"):
clf.fit(X, y)

# check that model has to be a predictor
clf = LinearModel(StandardScaler())
with pytest.raises(ValueError, match="classifier or regressor"):
clf.fit(X, y)

# check categorical target fit in standard linear model with GridSearchCV
parameters = {"kernel": ["linear"], "C": [1, 10]}
clf = LinearModel(
Expand Down Expand Up @@ -481,11 +491,4 @@ def test_cross_val_multiscore():
@parametrize_with_checks([LinearModel(LogisticRegression())])
def test_sklearn_compliance(estimator, check):
"""Test LinearModel compliance with sklearn."""
ignores = (
"check_estimators_overwrite_params", # self.model changes!
"check_dont_overwrite_parameters",
"check_parameters_default_constructible",
)
if any(ignore in str(check) for ignore in ignores):
return
check(estimator)
Loading