diff --git a/doc/_includes/ged.rst b/doc/_includes/ged.rst
index 8f5fc17131c..5146fef5ffa 100644
--- a/doc/_includes/ged.rst
+++ b/doc/_includes/ged.rst
@@ -14,7 +14,7 @@ This section describes the mathematical formulation and application of
Generalized Eigendecomposition (GED), often used in spatial filtering
and source separation algorithms, such as :class:`mne.decoding.CSP`,
:class:`mne.decoding.SPoC`, :class:`mne.decoding.SSD` and
-:class:`mne.preprocessing.Xdawn`.
+:class:`mne.decoding.XdawnTransformer`.
The core principle of GED is to find a set of channel weights (spatial filter)
that maximizes the ratio of signal power between two data features.
diff --git a/doc/api/decoding.rst b/doc/api/decoding.rst
index 788a62b42da..f8f2257825f 100644
--- a/doc/api/decoding.rst
+++ b/doc/api/decoding.rst
@@ -30,6 +30,7 @@ Decoding
SPoC
SSD
XdawnTransformer
+ SpatialFilter
Functions that assist with decoding and model fitting:
@@ -39,3 +40,4 @@ Functions that assist with decoding and model fitting:
compute_ems
cross_val_multiscore
get_coef
+ get_spatial_filter_from_estimator
diff --git a/doc/changes/dev/13332.newfeature.rst b/doc/changes/dev/13332.newfeature.rst
new file mode 100644
index 00000000000..018dfbb9094
--- /dev/null
+++ b/doc/changes/dev/13332.newfeature.rst
@@ -0,0 +1,4 @@
+Add :class:`mne.decoding.SpatialFilter`, returned by :func:`mne.decoding.get_spatial_filter_from_estimator`, for
+visualizing filters and patterns of :class:`mne.decoding.LinearModel`
+and, additionally, eigenvalues of GED-based transformers such as
+:class:`mne.decoding.XdawnTransformer` and :class:`mne.decoding.CSP`, by `Gennadiy Belonosov`_.
\ No newline at end of file
diff --git a/examples/decoding/decoding_csp_eeg.py b/examples/decoding/decoding_csp_eeg.py
index 5859edde166..758c674e16e 100644
--- a/examples/decoding/decoding_csp_eeg.py
+++ b/examples/decoding/decoding_csp_eeg.py
@@ -14,6 +14,7 @@
`PhysioNet documentation page `_.
The dataset is available at PhysioNet :footcite:`GoldbergerEtAl2000`.
"""
+
# Authors: Martin Billinger
#
# License: BSD-3-Clause
@@ -30,7 +31,7 @@
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
-from mne.decoding import CSP
+from mne.decoding import CSP, get_spatial_filter_from_estimator
from mne.io import concatenate_raws, read_raw_edf
print(__doc__)
@@ -95,10 +96,11 @@
class_balance = max(class_balance, 1.0 - class_balance)
print(f"Classification accuracy: {np.mean(scores)} / Chance level: {class_balance}")
-# plot CSP patterns estimated on full data for visualization
+# plot eigenvalues and patterns estimated on full data for visualization
csp.fit_transform(epochs_data, labels)
-
-csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5)
+spf = get_spatial_filter_from_estimator(csp, info=epochs.info)
+spf.plot_scree()
+spf.plot_patterns(components=np.arange(4))
# %%
# Look at performance over time
diff --git a/examples/decoding/decoding_spoc_CMC.py b/examples/decoding/decoding_spoc_CMC.py
index 2f85138d2b3..3accd5b2cd6 100644
--- a/examples/decoding/decoding_spoc_CMC.py
+++ b/examples/decoding/decoding_spoc_CMC.py
@@ -32,7 +32,7 @@
import mne
from mne import Epochs
from mne.datasets.fieldtrip_cmc import data_path
-from mne.decoding import SPoC
+from mne.decoding import SPoC, get_spatial_filter_from_estimator
# Define parameters
fname = data_path() / "SubjectCMC.ds"
@@ -82,9 +82,18 @@
# Plot the contributions to the detected components (i.e., the forward model)
spoc.fit(X, y)
-spoc.plot_patterns(meg_epochs.info)
+spf = get_spatial_filter_from_estimator(spoc, info=meg_epochs.info)
+spf.plot_scree()
+
+# Plot patterns for the first three components
+# with largest absolute generalized eigenvalues,
+# as we can see on the scree plot
+spf.plot_patterns(components=[0, 1, 2])
+
##############################################################################
# References
# ----------
# .. footbibliography::
+
+# %%
diff --git a/examples/decoding/decoding_xdawn_eeg.py b/examples/decoding/decoding_xdawn_eeg.py
index ab274963f31..1d1bf3f8760 100644
--- a/examples/decoding/decoding_xdawn_eeg.py
+++ b/examples/decoding/decoding_xdawn_eeg.py
@@ -10,6 +10,7 @@
Channels are concatenated and rescaled to create features vectors that will be
fed into a logistic regression.
"""
+
# Authors: Alexandre Barachant
#
# License: BSD-3-Clause
@@ -26,10 +27,9 @@
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler
-from mne import Epochs, EvokedArray, create_info, io, pick_types, read_events
+from mne import Epochs, io, pick_types, read_events
from mne.datasets import sample
-from mne.decoding import Vectorizer
-from mne.preprocessing import Xdawn
+from mne.decoding import Vectorizer, XdawnTransformer, get_spatial_filter_from_estimator
print(__doc__)
@@ -71,31 +71,33 @@
# Create classification pipeline
clf = make_pipeline(
- Xdawn(n_components=n_filter),
+ XdawnTransformer(n_components=n_filter),
Vectorizer(),
MinMaxScaler(),
OneVsRestClassifier(LogisticRegression(penalty="l1", solver="liblinear")),
)
-# Get the labels
-labels = epochs.events[:, -1]
+# Get the data and labels
+# X is of shape (n_epochs, n_channels, n_times)
+X = epochs.get_data(copy=False)
+y = epochs.events[:, -1]
# Cross validator
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
# Do cross-validation
-preds = np.empty(len(labels))
-for train, test in cv.split(epochs, labels):
- clf.fit(epochs[train], labels[train])
- preds[test] = clf.predict(epochs[test])
+preds = np.empty(len(y))
+for train, test in cv.split(epochs, y):
+ clf.fit(X[train], y[train])
+ preds[test] = clf.predict(X[test])
# Classification report
target_names = ["aud_l", "aud_r", "vis_l", "vis_r"]
-report = classification_report(labels, preds, target_names=target_names)
+report = classification_report(y, preds, target_names=target_names)
print(report)
# Normalized confusion matrix
-cm = confusion_matrix(labels, preds)
+cm = confusion_matrix(y, preds)
cm_normalized = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]
# Plot confusion matrix
@@ -109,30 +111,35 @@
ax.set(ylabel="True label", xlabel="Predicted label")
# %%
-# The ``patterns_`` attribute of a fitted Xdawn instance (here from the last
-# cross-validation fold) can be used for visualization.
-
-fig, axes = plt.subplots(
- nrows=len(event_id),
- ncols=n_filter,
- figsize=(n_filter, len(event_id) * 2),
- layout="constrained",
+# Patterns of a fitted XdawnTransformer instance (here from the last
+# cross-validation fold) can be visualized using a SpatialFilter container.
+
+# Instantiate SpatialFilter
+spf = get_spatial_filter_from_estimator(
+ clf, info=epochs.info, step_name="xdawntransformer"
+)
+
+# Let's first examine the scree plot of generalized eigenvalues
+# for each class.
+spf.plot_scree(title="")
+
+# We can see that for all four classes ~five largest components
+# capture most of the variance, let's plot their patterns.
+# Each class will now return its own figure
+components_to_plot = np.arange(5)
+figs = spf.plot_patterns(
+ # Indices of patterns to plot,
+ # we will plot the first five for each class
+ components=components_to_plot,
+ show=False, # to set the titles below
)
-fitted_xdawn = clf.steps[0][1]
-info = create_info(epochs.ch_names, 1, epochs.get_channel_types())
-info.set_montage(epochs.get_montage())
-for ii, cur_class in enumerate(sorted(event_id)):
- cur_patterns = fitted_xdawn.patterns_[cur_class]
- pattern_evoked = EvokedArray(cur_patterns[:n_filter].T, info, tmin=0)
- pattern_evoked.plot_topomap(
- times=np.arange(n_filter),
- time_format="Component %d" if ii == 0 else "",
- colorbar=False,
- show_names=False,
- axes=axes[ii],
- show=False,
- )
- axes[ii, 0].set(ylabel=cur_class)
+
+# Set the class titles
+event_id_reversed = {v: k for k, v in event_id.items()}
+for fig, class_idx in zip(figs, clf[0].classes_):
+ class_name = event_id_reversed[class_idx]
+ fig.suptitle(class_name, fontsize=16)
+
# %%
# References
diff --git a/examples/decoding/linear_model_patterns.py b/examples/decoding/linear_model_patterns.py
index 7373c0a18b3..48d679ed1fd 100644
--- a/examples/decoding/linear_model_patterns.py
+++ b/examples/decoding/linear_model_patterns.py
@@ -14,6 +14,7 @@
Note patterns/filters in MEG data are more similar than EEG data
because the noise is less spatially correlated in MEG than EEG.
"""
+
# Authors: Alexandre Gramfort
# Romain Trachel
# Jean-Rémi King
@@ -28,11 +29,16 @@
from sklearn.preprocessing import StandardScaler
import mne
-from mne import EvokedArray, io
+from mne import io
from mne.datasets import sample
# import a linear classifier from mne.decoding
-from mne.decoding import LinearModel, Vectorizer, get_coef
+from mne.decoding import (
+ LinearModel,
+ SpatialFilter,
+ Vectorizer,
+ get_spatial_filter_from_estimator,
+)
print(__doc__)
@@ -77,7 +83,7 @@
X = scaler.fit_transform(meg_data)
model.fit(X, labels)
-# Extract and plot spatial filters and spatial patterns
+coefs = dict()
for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)):
# We fit the linear model on Z-scored data. To make the filters
# interpretable, we must reverse this normalization step
@@ -85,12 +91,30 @@
# The data was vectorized to fit a single model across all time points and
# all channels. We thus reshape it:
- coef = coef.reshape(len(meg_epochs.ch_names), -1)
-
- # Plot
- evoked = EvokedArray(coef, meg_epochs.info, tmin=epochs.tmin)
- fig = evoked.plot_topomap()
- fig.suptitle(f"MEG {name}")
+ coefs[name] = coef.reshape(len(meg_epochs.ch_names), -1).T
+
+# Now we can instantiate the visualization container
+spf = SpatialFilter(info=meg_epochs.info, **coefs)
+fig = spf.plot_patterns(
+ # we will automatically select patterns
+ components="auto",
+ # as our filters and patterns correspond to actual times
+ # we can align them
+ tmin=epochs.tmin,
+ units="fT", # it's physical - we inversed the scaling
+ show=False, # to set the title below
+ name_format=None, # to plot actual times
+)
+fig.suptitle("MEG patterns")
+# Same for filters
+fig = spf.plot_filters(
+ components="auto",
+ tmin=epochs.tmin,
+ units="fT",
+ show=False,
+ name_format=None,
+)
+fig.suptitle("MEG filters")
# %%
# Let's do the same on EEG data using a scikit-learn pipeline
@@ -107,15 +131,26 @@
),
)
clf.fit(X, y)
-
-# Extract and plot patterns and filters
-for name in ("patterns_", "filters_"):
- # The `inverse_transform` parameter will call this method on any estimator
- # contained in the pipeline, in reverse order.
- coef = get_coef(clf, name, inverse_transform=True)
- evoked = EvokedArray(coef, epochs.info, tmin=epochs.tmin)
- fig = evoked.plot_topomap()
- fig.suptitle(f"EEG {name[:-1]}")
+spf = get_spatial_filter_from_estimator(
+ clf, info=epochs.info, inverse_transform=True, step_name="linearmodel"
+)
+fig = spf.plot_patterns(
+ components="auto",
+ tmin=epochs.tmin,
+ units="uV",
+ show=False,
+ name_format=None,
+)
+fig.suptitle("EEG patterns")
+# Same for filters
+fig = spf.plot_filters(
+ components="auto",
+ tmin=epochs.tmin,
+ units="uV",
+ show=False,
+ name_format=None,
+)
+fig.suptitle("EEG filters")
# %%
# References
diff --git a/examples/decoding/ssd_spatial_filters.py b/examples/decoding/ssd_spatial_filters.py
index e9ca9ba79cf..7938fe6ad2a 100644
--- a/examples/decoding/ssd_spatial_filters.py
+++ b/examples/decoding/ssd_spatial_filters.py
@@ -26,7 +26,7 @@
import mne
from mne import Epochs
from mne.datasets.fieldtrip_cmc import data_path
-from mne.decoding import SSD
+from mne.decoding import SSD, get_spatial_filter_from_estimator
# %%
# Define parameters
@@ -70,8 +70,8 @@
# (W^{-1}) or by multiplying the noise cov with the filters Eq. (22) (C_n W)^t.
# We rely on the inversion approach here.
-pattern = mne.EvokedArray(data=ssd.patterns_[:4].T, info=ssd.info)
-pattern.plot_topomap(units=dict(mag="A.U."), time_format="")
+spf = get_spatial_filter_from_estimator(ssd, info=ssd.info)
+spf.plot_patterns(components=list(range(4)))
# The topographies suggest that we picked up a parietal alpha generator.
@@ -150,8 +150,8 @@
ssd_epochs.fit(X=epochs.get_data(copy=False))
# Plot topographies.
-pattern_epochs = mne.EvokedArray(data=ssd_epochs.patterns_[:4].T, info=ssd_epochs.info)
-pattern_epochs.plot_topomap(units=dict(mag="A.U."), time_format="")
+spf = get_spatial_filter_from_estimator(ssd_epochs, info=ssd_epochs.info)
+spf.plot_patterns(components=list(range(4)))
# %%
# References
# ----------
diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi
index 6a1e7d8ab89..1131f1597c5 100644
--- a/mne/decoding/__init__.pyi
+++ b/mne/decoding/__init__.pyi
@@ -11,6 +11,7 @@ __all__ = [
"SSD",
"Scaler",
"SlidingEstimator",
+ "SpatialFilter",
"TemporalFilter",
"TimeDelayingRidge",
"TimeFrequency",
@@ -21,6 +22,7 @@ __all__ = [
"compute_ems",
"cross_val_multiscore",
"get_coef",
+ "get_spatial_filter_from_estimator",
]
from .base import (
BaseEstimator,
@@ -33,6 +35,7 @@ from .csp import CSP, SPoC
from .ems import EMS, compute_ems
from .receptive_field import ReceptiveField
from .search_light import GeneralizingEstimator, SlidingEstimator
+from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator
from .ssd import SSD
from .time_delaying_ridge import TimeDelayingRidge
from .time_frequency import TimeFrequency
diff --git a/mne/decoding/base.py b/mne/decoding/base.py
index adae374ea25..1fd64a5b22b 100644
--- a/mne/decoding/base.py
+++ b/mne/decoding/base.py
@@ -44,7 +44,7 @@
_smart_ged,
)
from ._mod_ged import _no_op_mod
-from .transformer import MNETransformerMixin
+from .transformer import MNETransformerMixin, Vectorizer
class _GEDTransformer(MNETransformerMixin, BaseEstimator):
@@ -585,8 +585,32 @@ def _get_inverse_funcs(estimator, terminal=True):
return inverse_func
+def _get_inverse_funcs_before_step(estimator, step_name):
+    """Collect ``inverse_transform`` methods of the steps preceding a target step.
+
+    ``step_name`` may address a step of a nested pipeline using ``__`` as
+    separator. At every nesting level, the steps listed before the addressed
+    one contribute their bound ``inverse_transform`` (in pipeline order, outer
+    level first); steps lacking that method trigger a warning and are skipped.
+    """
+    collected = []
+    pipeline = estimator
+    for level_name in step_name.split("__"):
+        names_at_level = [name for name, _ in pipeline.steps]
+        n_preceding = names_at_level.index(level_name)
+        for name, step in pipeline.steps[:n_preceding]:
+            if not hasattr(step, "inverse_transform"):
+                warn(
+                    f"Preceding step '{name}' is not invertible "
+                    "and will be skipped."
+                )
+                continue
+            collected.append(step.inverse_transform)
+        # descend into the addressed step before handling the next name part
+        pipeline = pipeline.named_steps[level_name]
+    return collected
+
+
@verbose
-def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=None):
+def get_coef(
+ estimator, attr="filters_", inverse_transform=False, *, step_name=None, verbose=None
+):
"""Retrieve the coefficients of an estimator ending with a Linear Model.
This is typically useful to retrieve "spatial filters" or "spatial
@@ -602,6 +626,13 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non
inverse_transform : bool
If True, returns the coefficients after inverse transforming them with
the transformer steps of the estimator.
+ step_name : str | None
+ Name of the sklearn pipeline step to get the coef from.
+ If inverse_transform is True, the inverse transformations
+ will be applied using transformers before this step.
+ If None, the last step will be used. Defaults to None.
+
+ .. versionadded:: 1.11
%(verbose)s
Returns
@@ -616,8 +647,17 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non
# Get the coefficients of the last estimator in case of nested pipeline
est = estimator
logger.debug(f"Getting coefficients from estimator: {est.__class__.__name__}")
- while hasattr(est, "steps"):
- est = est.steps[-1][1]
+
+ if step_name is not None:
+ if not hasattr(estimator, "named_steps"):
+ raise ValueError("step_name can only be used with a pipeline estimator.")
+ try:
+ est = est.get_params(deep=True)[step_name]
+ except KeyError:
+ raise ValueError(f"Step '{step_name}' is not part of the pipeline.")
+ else:
+ while hasattr(est, "steps"):
+ est = est.steps[-1][1]
squeeze_first_dim = False
@@ -646,9 +686,14 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non
raise ValueError(
"inverse_transform can only be applied onto pipeline estimators."
)
+ if step_name is None:
+ inverse_funcs = _get_inverse_funcs(estimator)
+ else:
+ inverse_funcs = _get_inverse_funcs_before_step(estimator, step_name)
+
# The inverse_transform parameter will call this method on any
# estimator contained in the pipeline, in reverse order.
- for inverse_func in _get_inverse_funcs(estimator)[::-1]:
+ for inverse_func in inverse_funcs[::-1]:
logger.debug(f" Applying inverse transformation: {inverse_func}.")
coef = inverse_func(coef)
@@ -656,6 +701,17 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non
logger.debug(" Squeezing first dimension of coefficients.")
coef = coef[0]
+ # inverse_transform with Vectorizer returns shape (n_channels, n_components).
+ # we should transpose to be consistent with how spatial filters
+ # store filters and patterns: (n_components, n_channels)
+ if inverse_transform and hasattr(estimator, "steps"):
+ is_vectorizer = any(
+ isinstance(param_value, Vectorizer)
+ for param_value in estimator.get_params(deep=True).values()
+ )
+ if is_vectorizer and coef.ndim == 2:
+ coef = coef.T
+
return coef
diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py
index be20b968f07..4071aa00ee0 100644
--- a/mne/decoding/csp.py
+++ b/mne/decoding/csp.py
@@ -3,22 +3,17 @@
# Copyright the MNE-Python contributors.
import collections.abc as abc
-import copy as cp
from functools import partial
import numpy as np
from .._fiff.meas_info import Info
from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
-from ..evoked import EvokedArray
-from ..utils import (
- _check_option,
- _validate_type,
- fill_doc,
-)
+from ..utils import _check_option, _validate_type, fill_doc, legacy
from ._covs_ged import _csp_estimate, _spoc_estimate
from ._mod_ged import _csp_mod, _spoc_mod
from .base import _GEDTransformer
+from .spatial_filter import get_spatial_filter_from_estimator
@fill_doc
@@ -316,6 +311,7 @@ def fit_transform(self, X, y=None, **fit_params):
# use parent TransformerMixin method but with custom docstring
return super().fit_transform(X, y=y, **fit_params)
+ @legacy(alt="get_spatial_filter_from_estimator(clf, info=info).plot_patterns()")
@fill_doc
def plot_patterns(
self,
@@ -402,20 +398,9 @@ def plot_patterns(
fig : instance of matplotlib.figure.Figure
The figure.
"""
- if units is None:
- units = "AU"
- if components is None:
- components = np.arange(self.n_components)
-
- # set sampling frequency to have 1 component per time point
- info = cp.deepcopy(info)
- with info._unlock():
- info["sfreq"] = 1.0
- # create an evoked
- patterns = EvokedArray(self.patterns_.T, info, tmin=0)
- # the call plot_topomap
- fig = patterns.plot_topomap(
- times=components,
+ spf = get_spatial_filter_from_estimator(self, info=info)
+ return spf.plot_patterns(
+ components,
ch_type=ch_type,
scalings=scalings,
sensors=sensors,
@@ -437,13 +422,13 @@ def plot_patterns(
cbar_fmt=cbar_fmt,
units=units,
axes=axes,
- time_format=name_format,
+ name_format=name_format,
nrows=nrows,
ncols=ncols,
show=show,
)
- return fig
+ @legacy(alt="get_spatial_filter_from_estimator(clf, info=info).plot_filters()")
@fill_doc
def plot_filters(
self,
@@ -530,20 +515,9 @@ def plot_filters(
fig : instance of matplotlib.figure.Figure
The figure.
"""
- if units is None:
- units = "AU"
- if components is None:
- components = np.arange(self.n_components)
-
- # set sampling frequency to have 1 component per time point
- info = cp.deepcopy(info)
- with info._unlock():
- info["sfreq"] = 1.0
- # create an evoked
- filters = EvokedArray(self.filters_.T, info, tmin=0)
- # the call plot_topomap
- fig = filters.plot_topomap(
- times=components,
+ spf = get_spatial_filter_from_estimator(self, info=info)
+ return spf.plot_filters(
+ components,
ch_type=ch_type,
scalings=scalings,
sensors=sensors,
@@ -565,12 +539,11 @@ def plot_filters(
cbar_fmt=cbar_fmt,
units=units,
axes=axes,
- time_format=name_format,
+ name_format=name_format,
nrows=nrows,
ncols=ncols,
show=show,
)
- return fig
def _ajd_pham(X, eps=1e-6, max_iter=15):
diff --git a/mne/decoding/spatial_filter.py b/mne/decoding/spatial_filter.py
new file mode 100644
index 00000000000..169cca7d005
--- /dev/null
+++ b/mne/decoding/spatial_filter.py
@@ -0,0 +1,639 @@
+# Authors: The MNE-Python contributors.
+# License: BSD-3-Clause
+# Copyright the MNE-Python contributors.
+
+import copy as cp
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT
+from ..evoked import EvokedArray
+from ..utils import _check_option, fill_doc, verbose
+from ..viz.utils import plt_show
+from .base import LinearModel, _GEDTransformer, get_coef
+
+
+def _plot_model(
+    model_array,
+    info,
+    components=None,
+    *,
+    evk_tmin=None,
+    ch_type=None,
+    scalings=None,
+    sensors=True,
+    show_names=False,
+    mask=None,
+    mask_params=None,
+    contours=6,
+    outlines="head",
+    sphere=None,
+    image_interp=_INTERPOLATION_DEFAULT,
+    extrapolate=_EXTRAPOLATE_DEFAULT,
+    border=_BORDER_DEFAULT,
+    res=64,
+    size=1,
+    cmap="RdBu_r",
+    vlim=(None, None),
+    cnorm=None,
+    colorbar=True,
+    cbar_fmt="%3.1f",
+    units=None,
+    axes=None,
+    name_format=None,
+    nrows=1,
+    ncols="auto",
+    show=True,
+):
+    """Plot a filters/patterns array as topomaps through ``EvokedArray``.
+
+    Parameters
+    ----------
+    model_array : ndarray
+        2D or 3D array whose second-to-last axis indexes the components.
+        A 3D array is split along its first axis, one figure per entry.
+    info : mne.Info
+        Measurement info used to build the ``EvokedArray``.
+    components : array-like | None
+        Passed as ``times`` to ``plot_topomap``. If None, every index along
+        the component axis is used.
+    evk_tmin : float | None
+        ``tmin`` of the ``EvokedArray``. If None, a copy of ``info`` with
+        ``sfreq`` set to 1 is used together with ``tmin=0``, so that there is
+        one component per time point.
+    units : str | None
+        Passed to ``plot_topomap``; ``"AU"`` is used if None.
+    axes : Axes | sequence | ndarray | None
+        Passed to ``plot_topomap``. For a 3D ``model_array`` it is indexed with
+        the first-axis index to get the axes of each figure.
+    name_format : str | None
+        Passed as ``time_format`` to ``plot_topomap``.
+
+    All other keyword arguments are forwarded to ``plot_topomap`` unchanged.
+
+    Returns
+    -------
+    fig : object | list
+        The return value of ``plot_topomap`` for a 2D ``model_array``, or a
+        list with one such value per first-axis entry for a 3D one.
+    """
+    if components is None:
+        n_comps = model_array.shape[-2]
+        components = np.arange(n_comps)
+    kwargs = dict(
+        # args set here
+        times=components,
+        average=None,
+        proj=False,
+        units="AU" if units is None else units,
+        time_format=name_format,
+        # args passed from the upstream
+        ch_type=ch_type,
+        scalings=scalings,
+        sensors=sensors,
+        show_names=show_names,
+        mask=mask,
+        mask_params=mask_params,
+        contours=contours,
+        outlines=outlines,
+        sphere=sphere,
+        image_interp=image_interp,
+        extrapolate=extrapolate,
+        border=border,
+        res=res,
+        size=size,
+        cmap=cmap,
+        vlim=vlim,
+        cnorm=cnorm,
+        colorbar=colorbar,
+        cbar_fmt=cbar_fmt,
+        nrows=nrows,
+        ncols=ncols,
+        show=show,
+    )
+
+    if evk_tmin is None:
+        # set sampling frequency to have 1 component per time point;
+        # work on a copy so the caller's info is left untouched
+        info = cp.deepcopy(info)
+        with info._unlock():
+            info["sfreq"] = 1.0
+        evk_tmin = 0
+
+    if model_array.ndim == 3:
+        # ``axes`` may be an ndarray (e.g. as returned by ``plt.subplots``),
+        # whose truth value is ambiguous and raises, so check for None and
+        # emptiness explicitly instead of relying on ``if axes``
+        has_axes = axes is not None and len(axes) > 0
+        n_classes = model_array.shape[0]
+        figs = list()
+        for class_idx in range(n_classes):
+            model_evk = EvokedArray(model_array[class_idx].T, info, tmin=evk_tmin)
+            fig = model_evk.plot_topomap(
+                axes=axes[class_idx] if has_axes else None, **kwargs
+            )
+            figs.append(fig)
+        return figs
+    else:
+        model_evk = EvokedArray(model_array.T, info, tmin=evk_tmin)
+        fig = model_evk.plot_topomap(axes=axes, **kwargs)
+        return fig
+
+
+def _plot_scree_per_class(evals, add_cumul_evals, axes):
+    """Draw the eigenvalues of one class as a scree plot on ``axes``.
+
+    The eigenvalues are plotted against their 0-based index. If
+    ``add_cumul_evals`` is True, their cumulative sum is added on a twin y-axis.
+    """
+    x_idx = np.arange(len(evals))
+    running_total = None
+    if add_cumul_evals:
+        running_total = np.cumsum(evals)
+
+    # individual eigenvalues on the primary axis
+    axes.set_xlabel("Component Index", fontsize=18)
+    axes.set_ylabel("Eigenvalue", fontsize=18)
+    axes.plot(x_idx, evals, color="cornflowerblue", marker="o", markersize=8)
+    for which in ("y", "x"):
+        axes.tick_params(axis=which, labelsize=16)
+
+    if not add_cumul_evals:
+        return
+
+    # cumulative eigenvalues on a twin y-axis created from ``axes``
+    cumul_color = "firebrick"
+    twin = axes.twinx()
+    twin.grid(False)
+    twin.set_ylabel("Cumulative Eigenvalues", fontsize=18)
+    twin.plot(x_idx, running_total, color=cumul_color, marker="o", markersize=6)
+    twin.tick_params(axis="y", labelcolor=cumul_color, labelsize=16)
+    twin.set_ylim(0)
+
+
+def _plot_scree(
+ evals,
+ title="Scree plot",
+ add_cumul_evals=True,
+ axes=None,
+):
+ # Draw one scree plot per class via _plot_scree_per_class.
+ # A 2D ``evals`` is iterated along its first axis (one entry per class);
+ # any other dimensionality is treated as a single class.
+ evals_data = evals if evals.ndim == 2 else [evals]
+ n_classes = len(evals_data)
+ # a single Axes instance is wrapped so it can be indexed like a sequence
+ axes = [axes] if isinstance(axes, plt.Axes) else axes
+ if axes is not None and n_classes != len(axes):
+ raise ValueError(f"Received {len(axes)} axes, but expected {n_classes}")
+
+ orig_axes = axes
+ figs = list()
+ for class_idx in range(n_classes):
+ fig = None
+ if orig_axes is None:
+ # no axes were supplied: open a new figure for this class
+ fig, ax = plt.subplots(figsize=(7, 4), layout="constrained")
+ else:
+ ax = axes[class_idx]
+ _plot_scree_per_class(evals_data[class_idx], add_cumul_evals, ax)
+ # the title is only applied to figures created in this function
+ if fig is not None:
+ fig.suptitle(title, fontsize=22)
+ figs.append(fig)
+
+ # a lone collected figure is returned as is, anything else as the list
+ return figs[0] if len(figs) == 1 else figs
+
+
+@verbose
+def get_spatial_filter_from_estimator(
+    estimator,
+    info,
+    *,
+    inverse_transform=False,
+    step_name=None,
+    get_coefs=("filters_", "patterns_", "evals_"),
+    patterns_method=None,
+    verbose=None,
+):
+    """Create a :class:`mne.decoding.SpatialFilter` from a fitted estimator.
+
+    The estimator can be a fitted generalized eigendecomposition transformer
+    or a :class:`mne.decoding.LinearModel`, possibly wrapped in a pipeline.
+    The returned object can be used to visualize spatial filters,
+    patterns, and eigenvalues.
+
+    Parameters
+    ----------
+    estimator : instance of sklearn.base.BaseEstimator
+        Sklearn-based estimator or meta-estimator from which to initialize
+        spatial filter. Use ``step_name`` to select relevant transformer
+        from the pipeline object (works with nested names using ``__`` syntax).
+    info : instance of mne.Info
+        The measurement info object for plotting topomaps.
+    inverse_transform : bool
+        If True, returns filters and patterns after inverse transforming them with
+        the transformer steps of the estimator. Eigenvalues are never
+        inverse transformed. Defaults to False.
+    step_name : str | None
+        Name of the pipeline step to get the coefficients from.
+        If ``inverse_transform`` is True, the inverse transformations
+        will be applied using transformers before this step.
+        If None, the last step will be used. Defaults to None.
+    get_coefs : tuple
+        The names of the coefficient attributes to retrieve, can include
+        ``'filters_'``, ``'patterns_'`` and ``'evals_'``.
+        If the step is a GED transformer, all three are used; if it is a
+        LinearModel, only ``'filters_'`` and ``'patterns_'`` are used.
+        Defaults to (``'filters_'``, ``'patterns_'``, ``'evals_'``).
+    patterns_method : str
+        The method used to compute the patterns. Can be None, ``'pinv'`` or ``'haufe'``.
+        It will be set automatically to ``'pinv'`` if step is GEDTransformer,
+        or to ``'haufe'`` if step is LinearModel. Defaults to None.
+    %(verbose)s
+
+    Returns
+    -------
+    sp_filter : instance of mne.decoding.SpatialFilter
+        The spatial filter object.
+
+    See Also
+    --------
+    SpatialFilter, mne.decoding.LinearModel, mne.decoding.CSP,
+    mne.decoding.SSD, mne.decoding.XdawnTransformer, mne.decoding.SPoC
+
+    Notes
+    -----
+    .. versionadded:: 1.11
+    """
+    allowed_coefs = ("filters_", "patterns_", "evals_")
+    for coef in get_coefs:
+        if coef in allowed_coefs:
+            continue
+        raise ValueError(
+            f"'get_coefs' can only include 'filters_', "
+            f"'patterns_' and 'evals_', but got {coef}."
+        )
+
+    # locate the object whose type decides which attributes are read
+    if step_name is not None:
+        model = estimator.get_params()[step_name]
+    elif hasattr(estimator, "named_steps"):
+        model = estimator[-1]
+    else:
+        model = estimator
+
+    # known model types override both the patterns method and the attributes
+    if isinstance(model, LinearModel):
+        patterns_method, get_coefs = "haufe", ["filters_", "patterns_"]
+    elif isinstance(model, _GEDTransformer):
+        patterns_method, get_coefs = "pinv", ["filters_", "patterns_", "evals_"]
+
+    coefs = dict()
+    for coef in get_coefs:
+        # eigenvalues are retrieved without any inverse transformation
+        do_inverse = inverse_transform if coef != "evals_" else False
+        # drop the trailing underscore to get SpatialFilter's keyword name
+        coefs[coef[:-1]] = get_coef(
+            estimator,
+            coef,
+            inverse_transform=do_inverse,
+            step_name=step_name,
+            verbose=verbose,
+        )
+
+    return SpatialFilter(info, patterns_method=patterns_method, **coefs)
+
+
+class SpatialFilter:
+ r"""Container for spatial filter weights (evecs) and patterns.
+
+ .. warning:: For MNE-Python decoding classes, this container should be
+ instantiated with `mne.decoding.get_spatial_filter_from_estimator`.
+ Direct instantiation with external spatial filters is possible
+ at your own risk.
+
+ This object is obtained either by generalized eigendecomposition (GED) algorithms
+ such as :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`,
+ :class:`mne.decoding.SSD`, :class:`mne.decoding.XdawnTransformer` or by
+ :class:`mne.decoding.LinearModel`, wrapping linear models like SVM or Logit.
+ The object stores the filters that project sensor data to a reduced component
+ space, and the corresponding patterns (obtained by pseudoinverse in GED case or
+ Haufe's trick in case of :class:`mne.decoding.LinearModel`). It can also be directly
+ initialized using filters from other transformers (e.g. PyRiemann),
+ but make sure that the dimensions match.
+
+ Parameters
+ ----------
+ info : instance of Info
+ The measurement info containing channel topography.
+ filters : ndarray, shape ((n_classes), n_components, n_channels)
+ The spatial filters (transposed eigenvectors of the decomposition).
+ evals : ndarray, shape ((n_classes), n_components) | None
+ The eigenvalues of the decomposition. Defaults to ``None``.
+ patterns : ndarray, shape ((n_classes), n_components, n_channels) | None
+ The patterns of the decomposition. If None, they will be computed
+ from the filters using pseudoinverse. Defaults to ``None``.
+ patterns_method : str
+ The method used to compute the patterns. Can be ``'pinv'`` or ``'haufe'``.
+ If ``patterns`` is None, it will be set to ``'pinv'``. Defaults to ``'pinv'``.
+
+ Attributes
+ ----------
+ info : instance of Info
+ The measurement info.
+ filters : ndarray, shape (n_components, n_channels)
+ The spatial filters (unmixing matrix). Applying these filters to the data
+ gives the component time series.
+ patterns : ndarray, shape (n_components, n_channels)
+ The spatial patterns (mixing matrix/forward model).
+ These represent the scalp topography of each component.
+ evals : ndarray, shape (n_components,)
+ The eigenvalues associated with each component.
+ patterns_method : str
+ The method used to compute the patterns from the filters.
+
+ See Also
+ --------
+ get_spatial_filter_from_estimator, mne.decoding.LinearModel, mne.decoding.CSP,
+ mne.decoding.SSD, mne.decoding.XdawnTransformer, mne.decoding.SPoC
+
+ Notes
+ -----
+ The spatial filters and patterns are stored with shape
+ ``(n_components, n_channels)``.
+
+ Filters and patterns are related by the following equation:
+
+ .. math::
+ \mathbf{A} = \mathbf{W}^{-1}
+
+ where :math:`\mathbf{A}` is the matrix of patterns (the mixing matrix) and
+ :math:`\mathbf{W}` is the matrix of filters (the unmixing matrix).
+
+ For a detailed discussion on the difference between filters and patterns for GED
+ see :footcite:`Cohen2022` and for linear models in
+ general see :footcite:`HaufeEtAl2014`.
+
+ .. versionadded:: 1.11
+
+ References
+ ----------
+ .. footbibliography::
+ """
+
+    def __init__(
+        self,
+        info,
+        filters,
+        *,
+        evals=None,
+        patterns=None,
+        patterns_method="pinv",
+    ):
+        # Validate and store the inputs; when no patterns are given they are
+        # derived from the filters by pseudoinverse.
+        _check_option(
+            "patterns_method",
+            patterns_method,
+            ("pinv", "haufe"),
+        )
+        self.info = info
+        self.evals = evals
+        self.filters = filters
+        n_comps, n_chs = self.filters.shape[-2:]
+        if patterns is None:
+            # Transpose only the last two axes: ``filters.T`` would reverse
+            # *all* axes of a 3D (n_classes, n_components, n_channels) array
+            # and produce patterns whose shape never matches the filters.
+            # numpy's pinv operates on stacks of matrices, so 2D and 3D
+            # filters are both handled here.
+            # XXX Perhaps mne.linalg.pinv can be improved to handle 3D also,
+            # then it could be used here to be consistent with GEDTransformer
+            self.patterns = np.linalg.pinv(np.swapaxes(filters, -1, -2))
+            self.patterns_method = "pinv"
+        else:
+            self.patterns = patterns
+            self.patterns_method = patterns_method
+
+        # In case of multi-target classification in LinearModel
+        # number of targets can be greater than number of channels.
+        if patterns_method != "haufe" and n_comps > n_chs:
+            raise ValueError(
+                "Number of components can't be greater "
+                "than number of channels in filters, "
+                "perhaps the provided matrix is transposed?"
+            )
+        if self.filters.shape != self.patterns.shape:
+            # the message parts are concatenated, so each needs its own
+            # trailing space
+            raise ValueError(
+                f"Shape mismatch between filters and patterns. "
+                f"Filters are {self.filters.shape}, "
+                f"while patterns are {self.patterns.shape}"
+            )
+
+ @fill_doc
+ def plot_filters(
+ self,
+ components=None,
+ tmin=None,
+ *,
+ ch_type=None,
+ scalings=None,
+ sensors=True,
+ show_names=False,
+ mask=None,
+ mask_params=None,
+ contours=6,
+ outlines="head",
+ sphere=None,
+ image_interp=_INTERPOLATION_DEFAULT,
+ extrapolate=_EXTRAPOLATE_DEFAULT,
+ border=_BORDER_DEFAULT,
+ res=64,
+ size=1,
+ cmap="RdBu_r",
+ vlim=(None, None),
+ cnorm=None,
+ colorbar=True,
+ cbar_fmt="%3.1f",
+ units=None,
+ axes=None,
+ name_format="Filter%01d",
+ nrows=1,
+ ncols="auto",
+ show=True,
+ ):
+ """Plot topographic maps of model filters.
+
+ Parameters
+ ----------
+ components : float | array of float | 'auto' | None
+ Indices of filters to plot. If "auto", the number of
+ ``axes`` determines the number of filters.
+ If None, all filters will be plotted. Defaults to None.
+ tmin : float | None
+ In case filters are distributed temporally,
+ this can be used to align them with times
+ and frequency. Use ``epochs.tmin``, for example.
+ Defaults to None.
+ %(ch_type_topomap)s
+ %(scalings_topomap)s
+ %(sensors_topomap)s
+ %(show_names_topomap)s
+ %(mask_evoked_topomap)s
+ %(mask_params_topomap)s
+ %(contours_topomap)s
+ %(outlines_topomap)s
+ %(sphere_topomap_auto)s
+ %(image_interp_topomap)s
+ %(extrapolate_topomap)s
+ %(border_topomap)s
+ %(res_topomap)s
+ %(size_topomap)s
+ %(cmap_topomap)s
+ %(vlim_plot_topomap_psd)s
+ %(cnorm)s
+ %(colorbar_topomap)s
+ %(cbar_fmt_topomap)s
+ %(units_topomap_evoked)s
+ %(axes_evoked_plot_topomap)s
+ name_format : str
+ String format for topomap values. Defaults to ``'Filter%%01d'``.
+ %(nrows_ncols_topomap)s
+ %(show)s
+
+ Returns
+ -------
+ fig : instance of matplotlib.figure.Figure
+ The figure.
+ """
+ fig = _plot_model(
+ self.filters,
+ self.info,
+ components=components,
+ evk_tmin=tmin,
+ ch_type=ch_type,
+ scalings=scalings,
+ sensors=sensors,
+ show_names=show_names,
+ mask=mask,
+ mask_params=mask_params,
+ contours=contours,
+ outlines=outlines,
+ sphere=sphere,
+ image_interp=image_interp,
+ extrapolate=extrapolate,
+ border=border,
+ res=res,
+ size=size,
+ cmap=cmap,
+ vlim=vlim,
+ cnorm=cnorm,
+ colorbar=colorbar,
+ cbar_fmt=cbar_fmt,
+ units=units,
+ axes=axes,
+ name_format=name_format,
+ nrows=nrows,
+ ncols=ncols,
+ show=show,
+ )
+ return fig
+
+ @fill_doc
+ def plot_patterns(
+ self,
+ components=None,
+ tmin=None,
+ *,
+ ch_type=None,
+ scalings=None,
+ sensors=True,
+ show_names=False,
+ mask=None,
+ mask_params=None,
+ contours=6,
+ outlines="head",
+ sphere=None,
+ image_interp=_INTERPOLATION_DEFAULT,
+ extrapolate=_EXTRAPOLATE_DEFAULT,
+ border=_BORDER_DEFAULT,
+ res=64,
+ size=1,
+ cmap="RdBu_r",
+ vlim=(None, None),
+ cnorm=None,
+ colorbar=True,
+ cbar_fmt="%3.1f",
+ units=None,
+ axes=None,
+ name_format="Pattern%01d",
+ nrows=1,
+ ncols="auto",
+ show=True,
+ ):
+ """Plot topographic maps of model patterns.
+
+ Parameters
+ ----------
+ components : float | array of float | 'auto' | None
+ Indices of patterns to plot. If "auto", the number of
+ ``axes`` determines the number of patterns.
+ If None, all patterns will be plotted. Defaults to None.
+ tmin : float | None
+ In case patterns are distributed temporally,
+ this can be used to align them with times
+ and frequency. Use ``epochs.tmin``, for example.
+ Defaults to None.
+ %(ch_type_topomap)s
+ %(scalings_topomap)s
+ %(sensors_topomap)s
+ %(show_names_topomap)s
+ %(mask_evoked_topomap)s
+ %(mask_params_topomap)s
+ %(contours_topomap)s
+ %(outlines_topomap)s
+ %(sphere_topomap_auto)s
+ %(image_interp_topomap)s
+ %(extrapolate_topomap)s
+ %(border_topomap)s
+ %(res_topomap)s
+ %(size_topomap)s
+ %(cmap_topomap)s
+ %(vlim_plot_topomap_psd)s
+ %(cnorm)s
+ %(colorbar_topomap)s
+ %(cbar_fmt_topomap)s
+ %(units_topomap_evoked)s
+ %(axes_evoked_plot_topomap)s
+ name_format : str
+ String format for topomap values. Defaults to ``'Pattern%%01d'``.
+ %(nrows_ncols_topomap)s
+ %(show)s
+
+ Returns
+ -------
+ fig : instance of matplotlib.figure.Figure
+ The figure.
+ """
+ fig = _plot_model(
+ self.patterns,
+ self.info,
+ components=components,
+ evk_tmin=tmin,
+ ch_type=ch_type,
+ scalings=scalings,
+ sensors=sensors,
+ show_names=show_names,
+ mask=mask,
+ mask_params=mask_params,
+ contours=contours,
+ outlines=outlines,
+ sphere=sphere,
+ image_interp=image_interp,
+ extrapolate=extrapolate,
+ border=border,
+ res=res,
+ size=size,
+ cmap=cmap,
+ vlim=vlim,
+ cnorm=cnorm,
+ colorbar=colorbar,
+ cbar_fmt=cbar_fmt,
+ units=units,
+ axes=axes,
+ name_format=name_format,
+ nrows=nrows,
+ ncols=ncols,
+ show=show,
+ )
+ return fig
+
+ @fill_doc
+ def plot_scree(
+ self,
+ title="Scree plot",
+ add_cumul_evals=False,
+ axes=None,
+ show=True,
+ ):
+ """Plot scree for GED eigenvalues.
+
+ Parameters
+ ----------
+ title : str
+ Title for the plot. Defaults to ``'Scree plot'``.
+ add_cumul_evals : bool
+ Whether to add second line and y-axis for cumulative eigenvalues.
+ Defaults to ``False``.
+ axes : instance of Axes | None
+ The matplotlib axes to plot to. Defaults to ``None``.
+ %(show)s
+
+ Returns
+ -------
+ fig : instance of matplotlib.figure.Figure
+ The figure.
+ """
+ if self.evals is None:
+ raise AttributeError("Can't plot scree if eigenvalues are not provided.")
+
+ fig = _plot_scree(
+ self.evals,
+ title=title,
+ add_cumul_evals=add_cumul_evals,
+ axes=axes,
+ )
+ plt_show(show, block=False)
+ return fig
diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py
index a41b3246ed2..68623876222 100644
--- a/mne/decoding/tests/test_base.py
+++ b/mne/decoding/tests/test_base.py
@@ -28,6 +28,7 @@
is_classifier,
is_regressor,
)
+from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge
from sklearn.model_selection import (
@@ -241,35 +242,133 @@ def transform(self, X):
@pytest.mark.parametrize("inverse", (True, False))
@pytest.mark.parametrize(
- "Scale, kwargs",
+ "clf",
[
- (Scaler, dict(info=None, scalings="mean")),
- (_Noop, dict()),
+ pytest.param(
+ make_pipeline(
+ Scaler(info=None, scalings="mean"),
+ SlidingEstimator(make_pipeline(LinearModel(Ridge()))),
+ ),
+ id="Scaler+SlidingEstimator",
+ ),
+ pytest.param(
+ make_pipeline(
+ _Noop(),
+ SlidingEstimator(make_pipeline(LinearModel(Ridge()))),
+ ),
+ id="Noop+SlidingEstimator",
+ ),
+ pytest.param(
+ SlidingEstimator(make_pipeline(StandardScaler(), LinearModel(Ridge()))),
+ id="SlidingEstimator+nested StandardScaler",
+ ),
],
)
-def test_get_coef_inverse_transform(inverse, Scale, kwargs):
+def test_get_coef_inverse_transform(inverse, clf):
"""Test get_coef with and without inverse_transform."""
- lm_regression = LinearModel(Ridge())
X, y, A = _make_data(n_samples=1000, n_features=3, n_targets=1)
- # Check with search_light and combination of preprocessing ending with sl:
- # slider = SlidingEstimator(make_pipeline(StandardScaler(), lm_regression))
- # XXX : line above should work but does not as only last step is
- # used in get_coef ...
- slider = SlidingEstimator(make_pipeline(lm_regression))
X = np.transpose([X, -X], [1, 2, 0]) # invert X across 2 time samples
- clf = make_pipeline(Scale(**kwargs), slider)
clf.fit(X, y)
patterns = get_coef(clf, "patterns_", inverse)
filters = get_coef(clf, "filters_", inverse)
assert_array_equal(filters.shape, patterns.shape, X.shape[1:])
# the two time samples get inverted patterns
assert_equal(patterns[0, 0], -patterns[0, 1])
+
for t in [0, 1]:
- filters_t = get_coef(
- clf.named_steps["slidingestimator"].estimators_[t], "filters_", False
+ if hasattr(clf, "named_steps"):
+ est_t = clf.named_steps["slidingestimator"].estimators_[t]
+ filters_t = get_coef(est_t, "filters_", inverse)
+ if inverse:
+ filters_t = clf[0].inverse_transform(filters_t.reshape(1, -1))[0]
+ else:
+ est_t = clf.estimators_[t]
+ filters_t = get_coef(est_t, "filters_", inverse)
+
+ assert_equal(filters_t, filters[:, t])
+
+
+def test_get_coef_inverse_step_name():
+ """Test get_coef with inverse_transform=True and a specific step_name."""
+ X, y, _ = _make_data(n_samples=100, n_features=5, n_targets=1)
+
+ # Test with a simple pipeline
+ pipe = make_pipeline(StandardScaler(), PCA(n_components=3), LinearModel(Ridge()))
+ pipe.fit(X, y)
+
+ coef_inv_actual = get_coef(
+ pipe, attr="patterns_", inverse_transform=True, step_name="linearmodel"
+ )
+ # inverse_transform expects 2D input, so reshape the single pattern to (1, -1).
+ coef_raw = pipe.named_steps["linearmodel"].patterns_.reshape(1, -1)
+ coef_inv_desired = pipe.named_steps["pca"].inverse_transform(coef_raw)
+ coef_inv_desired = pipe.named_steps["standardscaler"].inverse_transform(
+ coef_inv_desired
+ )
+
+ assert coef_inv_actual.shape == (X.shape[1],)
+ # Compare as a single-row 2D array to match the shape of coef_inv_desired.
+ assert_array_almost_equal(coef_inv_actual.reshape(1, -1), coef_inv_desired)
+
+ with pytest.raises(ValueError, match="inverse_transform"):
+ _ = get_coef(
+ pipe[-1], # LinearModel
+ "filters_",
+ inverse_transform=True,
+ )
+ with pytest.raises(ValueError, match="step_name"):
+ _ = get_coef(
+ SlidingEstimator(pipe),
+ "filters_",
+ inverse_transform=True,
+ step_name="slidingestimator__pipeline__linearmodel",
+ )
+
+ # Test with a nested pipeline to check __ parsing
+ inner_pipe = make_pipeline(PCA(n_components=3), LinearModel(Ridge()))
+ nested_pipe = make_pipeline(StandardScaler(), inner_pipe)
+ nested_pipe.fit(X, y)
+ coef_nested_inv_actual = get_coef(
+ nested_pipe,
+ attr="patterns_",
+ inverse_transform=True,
+ step_name="pipeline__linearmodel",
+ )
+ linearmodel = nested_pipe.named_steps["pipeline"].named_steps["linearmodel"]
+ pca = nested_pipe.named_steps["pipeline"].named_steps["pca"]
+ scaler = nested_pipe.named_steps["standardscaler"]
+
+ coef_nested_raw = linearmodel.patterns_.reshape(1, -1)
+ coef_nested_inv_desired = pca.inverse_transform(coef_nested_raw)
+ coef_nested_inv_desired = scaler.inverse_transform(coef_nested_inv_desired)
+
+ assert coef_nested_inv_actual.shape == (X.shape[1],)
+ assert_array_almost_equal(
+ coef_nested_inv_actual.reshape(1, -1), coef_nested_inv_desired
+ )
+
+ with pytest.raises(ValueError, match="i_do_not_exist"):
+ get_coef(
+ pipe, attr="patterns_", inverse_transform=True, step_name="i_do_not_exist"
+ )
+
+ class NonInvertibleTransformer(BaseEstimator, TransformerMixin):
+ def fit(self, X, y=None):
+ return self
+
+ def transform(self, X):
+ # In a real scenario, this would modify X
+ return X
+
+ pipe = make_pipeline(NonInvertibleTransformer(), LinearModel(Ridge()))
+ pipe.fit(X, y)
+ with pytest.warns(RuntimeWarning, match="not invertible"):
+ _ = get_coef(
+ pipe,
+ "filters_",
+ inverse_transform=True,
+ step_name="linearmodel",
)
- if Scale is _Noop:
- assert_array_equal(filters_t, filters[:, t])
@pytest.mark.parametrize("n_features", [1, 5])
@@ -311,7 +410,15 @@ def test_get_coef_multiclass(n_features, n_targets):
if n_features > 1 and n_targets > 1:
assert_allclose(A, lm.patterns_.T, atol=2e-2)
coef = get_coef(clf, "patterns_", inverse_transform=True)
- lm_patterns_ = lm.patterns_[..., np.newaxis]
+
+ lm_patterns_ = lm.patterns_
+ # Expected shape is (n_targets, n_features)
+ # which is equivalent to (n_components, n_channels)
+ # in spatial filters
+ if lm_patterns_.ndim == 1:
+ lm_patterns_ = lm_patterns_[np.newaxis, :]
+ else:
+ lm_patterns_ = lm_patterns_[..., np.newaxis]
assert_allclose(lm_patterns_, coef, atol=1e-5)
# Check can pass fitting parameters
diff --git a/mne/decoding/tests/test_spatial_filter.py b/mne/decoding/tests/test_spatial_filter.py
new file mode 100644
index 00000000000..385b73fc053
--- /dev/null
+++ b/mne/decoding/tests/test_spatial_filter.py
@@ -0,0 +1,190 @@
+# Authors: The MNE-Python contributors.
+# License: BSD-3-Clause
+# Copyright the MNE-Python contributors.
+
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pytest
+from numpy.testing import assert_array_equal
+
+pytest.importorskip("sklearn")
+
+from sklearn.linear_model import LinearRegression
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+
+from mne import Epochs, create_info, io, pick_types, read_events
+from mne.decoding import (
+ CSP,
+ LinearModel,
+ SpatialFilter,
+ Vectorizer,
+ XdawnTransformer,
+ get_spatial_filter_from_estimator,
+)
+
+data_dir = Path(__file__).parents[2] / "io" / "tests" / "data"
+raw_fname = data_dir / "test_raw.fif"
+event_name = data_dir / "test-eve.fif"
+tmin, tmax = -0.1, 0.2
+event_id = dict(aud_l=1, vis_l=3)
+start, stop = 0, 8
+
+
+def _get_X_y(event_id, return_info=False):
+ raw = io.read_raw(raw_fname, preload=False)
+ events = read_events(event_name)
+ picks = pick_types(
+ raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads"
+ )
+ picks = picks[2:12:3] # subselect channels -> disable proj!
+ raw.add_proj([], remove_existing=True)
+ epochs = Epochs(
+ raw,
+ events,
+ event_id,
+ tmin,
+ tmax,
+ picks=picks,
+ baseline=(None, 0),
+ preload=True,
+ proj=False,
+ )
+ X = epochs.get_data(copy=False, units=dict(eeg="uV", grad="fT/cm", mag="fT"))
+ y = epochs.events[:, -1]
+ if return_info:
+ return X, y, epochs.info
+ return X, y
+
+
+def test_spatial_filter_init():
+ """Test the initialization of the SpatialFilter class."""
+ # Test initialization and factory function
+ rng = np.random.RandomState(0)
+ n, n_features = 20, 3
+ X = rng.rand(n, n_features)
+ n_targets = 5
+ y = rng.rand(n, n_targets)
+ clf = LinearModel(LinearRegression())
+ clf.fit(X, y)
+
+ # test get_spatial_filter_from_estimator for LinearModel
+ info = create_info(n_features, 1000.0, "eeg")
+ sp_filter = get_spatial_filter_from_estimator(clf, info)
+ assert sp_filter.patterns_method == "haufe"
+ assert_array_equal(sp_filter.filters, clf.filters_)
+ assert_array_equal(sp_filter.patterns, clf.patterns_)
+ assert sp_filter.evals is None
+
+ with pytest.raises(ValueError, match="can only include"):
+ _ = get_spatial_filter_from_estimator(
+ clf, info, get_coefs=("foo", "foo", "foo")
+ )
+
+ event_id = dict(aud_l=1, vis_l=3)
+ X, y, info = _get_X_y(event_id, return_info=True)
+ estimator = make_pipeline(Vectorizer(), StandardScaler(), CSP(n_components=4))
+ estimator.fit(X, y)
+ csp = estimator[-1]
+ # test get_spatial_filter_from_estimator for GED
+ sp_filter = get_spatial_filter_from_estimator(estimator, info, step_name="csp")
+ assert sp_filter.patterns_method == "pinv"
+ assert_array_equal(sp_filter.filters, csp.filters_)
+ assert_array_equal(sp_filter.patterns, csp.patterns_)
+ assert_array_equal(sp_filter.evals, csp.evals_)
+ assert sp_filter.info is info
+
+ # test without step_name
+ sp_filter = get_spatial_filter_from_estimator(estimator, info)
+ assert_array_equal(sp_filter.filters, csp.filters_)
+ assert_array_equal(sp_filter.patterns, csp.patterns_)
+ assert_array_equal(sp_filter.evals, csp.evals_)
+
+ # test basic initialization
+ sp_filter = SpatialFilter(
+ info, filters=csp.filters_, patterns=csp.patterns_, evals=csp.evals_
+ )
+ assert_array_equal(sp_filter.filters, csp.filters_)
+ assert_array_equal(sp_filter.patterns, csp.patterns_)
+ assert_array_equal(sp_filter.evals, csp.evals_)
+ assert sp_filter.info is info
+
+ # test automatic pattern calculation via pinv
+ sp_filter_pinv = SpatialFilter(info, filters=csp.filters_, evals=csp.evals_)
+ patterns_pinv = np.linalg.pinv(csp.filters_.T)
+ assert_array_equal(sp_filter_pinv.patterns, patterns_pinv)
+ assert sp_filter_pinv.patterns_method == "pinv"
+
+ # test shape mismatch error
+ with pytest.raises(ValueError, match="Shape mismatch"):
+ SpatialFilter(info, filters=csp.filters_, patterns=csp.patterns_[:-1])
+
+ # test invalid patterns_method
+ with pytest.raises(ValueError, match="patterns_method"):
+ SpatialFilter(info, filters=csp.filters_, patterns_method="foo")
+
+ # test n_components > n_channels error
+ bad_filters = np.random.randn(31, 30) # 31 components, 30 channels
+ with pytest.raises(ValueError, match="Number of components can't be greater"):
+ SpatialFilter(info, filters=bad_filters)
+
+
+def test_spatial_filter_plotting():
+ """Test the plotting methods of SpatialFilter."""
+ event_id = dict(aud_l=1, vis_l=3)
+ X, y, info = _get_X_y(event_id, return_info=True)
+ csp = CSP(n_components=4)
+ csp.fit(X, y)
+
+ sp_filter = get_spatial_filter_from_estimator(csp, info)
+
+ # test plot_filters
+ fig_filters = sp_filter.plot_filters(components=[0, 1], show=False)
+ assert isinstance(fig_filters, plt.Figure)
+ plt.close("all")
+
+ # test plot_patterns
+ fig_patterns = sp_filter.plot_patterns(show=False)
+ assert isinstance(fig_patterns, plt.Figure)
+ plt.close("all")
+
+ # test plot_scree
+ fig_scree = sp_filter.plot_scree(show=False, add_cumul_evals=True)
+ assert isinstance(fig_scree, plt.Figure)
+ plt.close("all")
+ _, axes = plt.subplots(figsize=(12, 7), layout="constrained")
+ fig_scree = sp_filter.plot_scree(axes=axes, show=False)
+ assert fig_scree == list()
+ plt.close("all")
+
+ # test plot_scree raises error if evals is None
+ sp_filter_no_evals = SpatialFilter(info, filters=csp.filters_, evals=None)
+ with pytest.raises(AttributeError, match="eigenvalues are not provided"):
+ sp_filter_no_evals.plot_scree()
+
+ # 3D case ('multi' GED decomposition)
+ n_classes = 2
+ event_id = dict(aud_l=1, vis_l=3)
+ X, y, info = _get_X_y(event_id, return_info=True)
+ xdawn = XdawnTransformer(n_components=4)
+ xdawn.fit(X, y)
+ sp_filter = get_spatial_filter_from_estimator(xdawn, info)
+
+ fig_patterns = sp_filter.plot_patterns(show=False)
+ assert len(fig_patterns) == n_classes
+ plt.close("all")
+
+ fig_scree = sp_filter.plot_scree(show=False)
+ assert len(fig_scree) == n_classes
+ plt.close("all")
+
+ with pytest.raises(ValueError, match="but expected"):
+ _, axes = plt.subplots(figsize=(12, 7), layout="constrained")
+ _ = sp_filter.plot_scree(axes=axes, show=False)
+
+ _, axes = plt.subplots(n_classes, figsize=(12, 7), layout="constrained")
+ fig_scree = sp_filter.plot_scree(axes=axes, show=False)
+ assert fig_scree == list()
+ plt.close("all")
diff --git a/tutorials/machine-learning/50_decoding.py b/tutorials/machine-learning/50_decoding.py
index c2a56ce0555..30f1204598d 100644
--- a/tutorials/machine-learning/50_decoding.py
+++ b/tutorials/machine-learning/50_decoding.py
@@ -47,6 +47,7 @@
Vectorizer,
cross_val_multiscore,
get_coef,
+ get_spatial_filter_from_estimator,
)
data_path = sample.data_path()
@@ -285,12 +286,16 @@
# This is also called the mixing matrix. The example :ref:`ex-linear-patterns`
# discusses the difference between patterns and filters.
#
-# These can be plotted with:
+# These can be plotted for every spatial filter including CSP, XdawnTransformer,
+# SSD and SPoC:
-# Fit CSP on full data and plot
+# Fit CSP on full data, plot eigenvalues sorted based on mutual information,
+# and plot patterns and filters for the three largest components.
csp.fit(X, y)
-csp.plot_patterns(epochs.info)
-csp.plot_filters(epochs.info, scalings=1e-9)
+spf = get_spatial_filter_from_estimator(csp, info=epochs.info)
+spf.plot_scree()
+spf.plot_patterns(components=[0, 1, 2])
+spf.plot_filters(components=[0, 1, 2], scalings=1e-9)
# %%
# Decoding over time