diff --git a/doc/_includes/ged.rst b/doc/_includes/ged.rst index 8f5fc17131c..5146fef5ffa 100644 --- a/doc/_includes/ged.rst +++ b/doc/_includes/ged.rst @@ -14,7 +14,7 @@ This section describes the mathematical formulation and application of Generalized Eigendecomposition (GED), often used in spatial filtering and source separation algorithms, such as :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, :class:`mne.decoding.SSD` and -:class:`mne.preprocessing.Xdawn`. +:class:`mne.decoding.XdawnTransformer`. The core principle of GED is to find a set of channel weights (spatial filter) that maximizes the ratio of signal power between two data features. diff --git a/doc/api/decoding.rst b/doc/api/decoding.rst index 788a62b42da..f8f2257825f 100644 --- a/doc/api/decoding.rst +++ b/doc/api/decoding.rst @@ -30,6 +30,7 @@ Decoding SPoC SSD XdawnTransformer + SpatialFilter Functions that assist with decoding and model fitting: @@ -39,3 +40,4 @@ Functions that assist with decoding and model fitting: compute_ems cross_val_multiscore get_coef + get_spatial_filter_from_estimator diff --git a/doc/changes/dev/13332.newfeature.rst b/doc/changes/dev/13332.newfeature.rst new file mode 100644 index 00000000000..018dfbb9094 --- /dev/null +++ b/doc/changes/dev/13332.newfeature.rst @@ -0,0 +1,4 @@ +Implement :class:`mne.decoding.SpatialFilter` class returned by :func:`mne.decoding.get_spatial_filter_from_estimator` for +visualisation of filters and patterns for :class:`mne.decoding.LinearModel` +and additionally eigenvalues for GED-based transformers such as +:class:`mne.decoding.XdawnTransformer`, :class:`mne.decoding.CSP`, by `Gennadiy Belonosov`_. \ No newline at end of file diff --git a/examples/decoding/decoding_csp_eeg.py b/examples/decoding/decoding_csp_eeg.py index 5859edde166..758c674e16e 100644 --- a/examples/decoding/decoding_csp_eeg.py +++ b/examples/decoding/decoding_csp_eeg.py @@ -14,6 +14,7 @@ `PhysioNet documentation page `_. The dataset is available at PhysioNet :footcite:`GoldbergerEtAl2000`. """ + # Authors: Martin Billinger # # License: BSD-3-Clause @@ -30,7 +31,7 @@ from mne import Epochs, pick_types from mne.channels import make_standard_montage from mne.datasets import eegbci -from mne.decoding import CSP +from mne.decoding import CSP, get_spatial_filter_from_estimator from mne.io import concatenate_raws, read_raw_edf print(__doc__) @@ -95,10 +96,11 @@ class_balance = max(class_balance, 1.0 - class_balance) print(f"Classification accuracy: {np.mean(scores)} / Chance level: {class_balance}") -# plot CSP patterns estimated on full data for visualization +# plot eigenvalues and patterns estimated on full data for visualization csp.fit_transform(epochs_data, labels) - -csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5) +spf = get_spatial_filter_from_estimator(csp, info=epochs.info) +spf.plot_scree() +spf.plot_patterns(components=np.arange(4)) # %% # Look at performance over time diff --git a/examples/decoding/decoding_spoc_CMC.py b/examples/decoding/decoding_spoc_CMC.py index 2f85138d2b3..3accd5b2cd6 100644 --- a/examples/decoding/decoding_spoc_CMC.py +++ b/examples/decoding/decoding_spoc_CMC.py @@ -32,7 +32,7 @@ import mne from mne import Epochs from mne.datasets.fieldtrip_cmc import data_path -from mne.decoding import SPoC +from mne.decoding import SPoC, get_spatial_filter_from_estimator # Define parameters fname = data_path() / "SubjectCMC.ds" @@ -82,9 +82,18 @@ # Plot the contributions to the detected components (i.e., the forward model) spoc.fit(X, y) -spoc.plot_patterns(meg_epochs.info) +spf = get_spatial_filter_from_estimator(spoc, info=meg_epochs.info) +spf.plot_scree() + +# Plot patterns for the first three components +# with largest absolute generalized eigenvalues, +# as we can see on the scree plot +spf.plot_patterns(components=[0, 1, 2]) + ############################################################################## # References # ---------- # .. footbibliography:: + +# %% diff --git a/examples/decoding/decoding_xdawn_eeg.py b/examples/decoding/decoding_xdawn_eeg.py index ab274963f31..1d1bf3f8760 100644 --- a/examples/decoding/decoding_xdawn_eeg.py +++ b/examples/decoding/decoding_xdawn_eeg.py @@ -10,6 +10,7 @@ Channels are concatenated and rescaled to create features vectors that will be fed into a logistic regression. """ + # Authors: Alexandre Barachant # # License: BSD-3-Clause @@ -26,10 +27,9 @@ from sklearn.pipeline import make_pipeline from sklearn.preprocessing import MinMaxScaler -from mne import Epochs, EvokedArray, create_info, io, pick_types, read_events +from mne import Epochs, io, pick_types, read_events from mne.datasets import sample -from mne.decoding import Vectorizer -from mne.preprocessing import Xdawn +from mne.decoding import Vectorizer, XdawnTransformer, get_spatial_filter_from_estimator print(__doc__) @@ -71,31 +71,33 @@ # Create classification pipeline clf = make_pipeline( - Xdawn(n_components=n_filter), + XdawnTransformer(n_components=n_filter), Vectorizer(), MinMaxScaler(), OneVsRestClassifier(LogisticRegression(penalty="l1", solver="liblinear")), ) -# Get the labels -labels = epochs.events[:, -1] +# Get the data and labels +# X is of shape (n_epochs, n_channels, n_times) +X = epochs.get_data(copy=False) +y = epochs.events[:, -1] # Cross validator cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42) # Do cross-validation -preds = np.empty(len(labels)) -for train, test in cv.split(epochs, labels): - clf.fit(epochs[train], labels[train]) - preds[test] = clf.predict(epochs[test]) +preds = np.empty(len(y)) +for train, test in cv.split(epochs, y): + clf.fit(X[train], y[train]) + preds[test] = clf.predict(X[test]) # Classification report target_names = ["aud_l", "aud_r", "vis_l", "vis_r"] -report = classification_report(labels, preds, target_names=target_names) +report = classification_report(y, preds, target_names=target_names) print(report) # Normalized confusion matrix -cm = confusion_matrix(labels, preds) +cm = confusion_matrix(y, preds) cm_normalized = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis] # Plot confusion matrix @@ -109,30 +111,35 @@ ax.set(ylabel="True label", xlabel="Predicted label") # %% -# The ``patterns_`` attribute of a fitted Xdawn instance (here from the last -# cross-validation fold) can be used for visualization. - -fig, axes = plt.subplots( - nrows=len(event_id), - ncols=n_filter, - figsize=(n_filter, len(event_id) * 2), - layout="constrained", +# Patterns of a fitted XdawnTransformer instance (here from the last +# cross-validation fold) can be visualized using SpatialFilter container. + +# Instantiate SpatialFilter +spf = get_spatial_filter_from_estimator( + clf, info=epochs.info, step_name="xdawntransformer" +) + +# Let's first examine the scree plot of generalized eigenvalues +# for each class. +spf.plot_scree(title="") + +# We can see that for all four classes ~five largest components +# capture most of the variance, let's plot their patterns. +# Each class will now return its own figure +components_to_plot = np.arange(5) +figs = spf.plot_patterns( + # Indices of patterns to plot, + # we will plot the first three for each class + components=components_to_plot, + show=False, # to set the titles below ) -fitted_xdawn = clf.steps[0][1] -info = create_info(epochs.ch_names, 1, epochs.get_channel_types()) -info.set_montage(epochs.get_montage()) -for ii, cur_class in enumerate(sorted(event_id)): - cur_patterns = fitted_xdawn.patterns_[cur_class] - pattern_evoked = EvokedArray(cur_patterns[:n_filter].T, info, tmin=0) - pattern_evoked.plot_topomap( - times=np.arange(n_filter), - time_format="Component %d" if ii == 0 else "", - colorbar=False, - show_names=False, - axes=axes[ii], - show=False, - ) - axes[ii, 0].set(ylabel=cur_class) + +# Set the class titles +event_id_reversed = {v: k for k, v in event_id.items()} +for fig, class_idx in zip(figs, clf[0].classes_): + class_name = event_id_reversed[class_idx] + fig.suptitle(class_name, fontsize=16) + # %% # References diff --git a/examples/decoding/linear_model_patterns.py b/examples/decoding/linear_model_patterns.py index 7373c0a18b3..48d679ed1fd 100644 --- a/examples/decoding/linear_model_patterns.py +++ b/examples/decoding/linear_model_patterns.py @@ -14,6 +14,7 @@ Note patterns/filters in MEG data are more similar than EEG data because the noise is less spatially correlated in MEG than EEG. """ + # Authors: Alexandre Gramfort # Romain Trachel # Jean-Rémi King @@ -28,11 +29,16 @@ from sklearn.preprocessing import StandardScaler import mne -from mne import EvokedArray, io +from mne import io from mne.datasets import sample # import a linear classifier from mne.decoding -from mne.decoding import LinearModel, Vectorizer, get_coef +from mne.decoding import ( + LinearModel, + SpatialFilter, + Vectorizer, + get_spatial_filter_from_estimator, +) print(__doc__) @@ -77,7 +83,7 @@ X = scaler.fit_transform(meg_data) model.fit(X, labels) -# Extract and plot spatial filters and spatial patterns +coefs = dict() for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)): # We fit the linear model on Z-scored data. To make the filters # interpretable, we must reverse this normalization step @@ -85,12 +91,30 @@ # The data was vectorized to fit a single model across all time points and # all channels. We thus reshape it: - coef = coef.reshape(len(meg_epochs.ch_names), -1) - - # Plot - evoked = EvokedArray(coef, meg_epochs.info, tmin=epochs.tmin) - fig = evoked.plot_topomap() - fig.suptitle(f"MEG {name}") + coefs[name] = coef.reshape(len(meg_epochs.ch_names), -1).T + +# Now we can instantiate the visualization container +spf = SpatialFilter(info=meg_epochs.info, **coefs) +fig = spf.plot_patterns( + # we will automatically select patterns + components="auto", + # as our filters and patterns correspond to actual times + # we can align them + tmin=epochs.tmin, + units="fT", # it's physical - we inversed the scaling + show=False, # to set the title below + name_format=None, # to plot actual times +) +fig.suptitle("MEG patterns") +# Same for filters +fig = spf.plot_filters( + components="auto", + tmin=epochs.tmin, + units="fT", + show=False, + name_format=None, +) +fig.suptitle("MEG filters") # %% # Let's do the same on EEG data using a scikit-learn pipeline @@ -107,15 +131,26 @@ ), ) clf.fit(X, y) - -# Extract and plot patterns and filters -for name in ("patterns_", "filters_"): - # The `inverse_transform` parameter will call this method on any estimator - # contained in the pipeline, in reverse order. - coef = get_coef(clf, name, inverse_transform=True) - evoked = EvokedArray(coef, epochs.info, tmin=epochs.tmin) - fig = evoked.plot_topomap() - fig.suptitle(f"EEG {name[:-1]}") +spf = get_spatial_filter_from_estimator( + clf, info=epochs.info, inverse_transform=True, step_name="linearmodel" +) +fig = spf.plot_patterns( + components="auto", + tmin=epochs.tmin, + units="uV", + show=False, + name_format=None, +) +fig.suptitle("EEG patterns") +# Same for filters +fig = spf.plot_filters( + components="auto", + tmin=epochs.tmin, + units="uV", + show=False, + name_format=None, +) +fig.suptitle("EEG filters") # %% # References diff --git a/examples/decoding/ssd_spatial_filters.py b/examples/decoding/ssd_spatial_filters.py index e9ca9ba79cf..7938fe6ad2a 100644 --- a/examples/decoding/ssd_spatial_filters.py +++ b/examples/decoding/ssd_spatial_filters.py @@ -26,7 +26,7 @@ import mne from mne import Epochs from mne.datasets.fieldtrip_cmc import data_path -from mne.decoding import SSD +from mne.decoding import SSD, get_spatial_filter_from_estimator # %% # Define parameters @@ -70,8 +70,8 @@ # (W^{-1}) or by multiplying the noise cov with the filters Eq. (22) (C_n W)^t. # We rely on the inversion approach here. -pattern = mne.EvokedArray(data=ssd.patterns_[:4].T, info=ssd.info) -pattern.plot_topomap(units=dict(mag="A.U."), time_format="") +spf = get_spatial_filter_from_estimator(ssd, info=ssd.info) +spf.plot_patterns(components=list(range(4))) # The topographies suggest that we picked up a parietal alpha generator. @@ -150,8 +150,8 @@ ssd_epochs.fit(X=epochs.get_data(copy=False)) # Plot topographies. -pattern_epochs = mne.EvokedArray(data=ssd_epochs.patterns_[:4].T, info=ssd_epochs.info) -pattern_epochs.plot_topomap(units=dict(mag="A.U."), time_format="") +spf = get_spatial_filter_from_estimator(ssd_epochs, info=ssd_epochs.info) +spf.plot_patterns(components=list(range(4))) # %% # References # ---------- diff --git a/mne/decoding/__init__.pyi b/mne/decoding/__init__.pyi index 6a1e7d8ab89..1131f1597c5 100644 --- a/mne/decoding/__init__.pyi +++ b/mne/decoding/__init__.pyi @@ -11,6 +11,7 @@ __all__ = [ "SSD", "Scaler", "SlidingEstimator", + "SpatialFilter", "TemporalFilter", "TimeDelayingRidge", "TimeFrequency", @@ -21,6 +22,7 @@ __all__ = [ "compute_ems", "cross_val_multiscore", "get_coef", + "get_spatial_filter_from_estimator", ] from .base import ( BaseEstimator, @@ -33,6 +35,7 @@ from .csp import CSP, SPoC from .ems import EMS, compute_ems from .receptive_field import ReceptiveField from .search_light import GeneralizingEstimator, SlidingEstimator +from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator from .ssd import SSD from .time_delaying_ridge import TimeDelayingRidge from .time_frequency import TimeFrequency diff --git a/mne/decoding/base.py b/mne/decoding/base.py index adae374ea25..1fd64a5b22b 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -44,7 +44,7 @@ _smart_ged, ) from ._mod_ged import _no_op_mod -from .transformer import MNETransformerMixin +from .transformer import MNETransformerMixin, Vectorizer class _GEDTransformer(MNETransformerMixin, BaseEstimator): @@ -585,8 +585,32 @@ def _get_inverse_funcs(estimator, terminal=True): return inverse_func +def _get_inverse_funcs_before_step(estimator, step_name): + """Get the inverse_transform methods for all steps before a target step.""" + # in case step_name is nested with __ + parts = step_name.split("__") + inverse_funcs = list() + current_pipeline = estimator + for part_name in parts: + all_names = [name for name, _ in current_pipeline.steps] + part_idx = all_names.index(part_name) + # get all preceding steps for the current step + for prec_name, prec_step in current_pipeline.steps[:part_idx]: + if hasattr(prec_step, "inverse_transform"): + inverse_funcs.append(prec_step.inverse_transform) + else: + warn( + f"Preceding step '{prec_name}' is not invertible " + f"and will be skipped." + ) + current_pipeline = current_pipeline.named_steps[part_name] + return inverse_funcs + + @verbose -def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=None): +def get_coef( + estimator, attr="filters_", inverse_transform=False, *, step_name=None, verbose=None +): """Retrieve the coefficients of an estimator ending with a Linear Model. This is typically useful to retrieve "spatial filters" or "spatial @@ -602,6 +626,13 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non inverse_transform : bool If True, returns the coefficients after inverse transforming them with the transformer steps of the estimator. + step_name : str | None + Name of the sklearn's pipeline step to get the coef from. + If inverse_transform is True, the inverse transformations + will be applied using transformers before this step. + If None, the last step will be used. Defaults to None. + + .. versionadded:: 1.11 %(verbose)s Returns @@ -616,8 +647,17 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non # Get the coefficients of the last estimator in case of nested pipeline est = estimator logger.debug(f"Getting coefficients from estimator: {est.__class__.__name__}") - while hasattr(est, "steps"): - est = est.steps[-1][1] + + if step_name is not None: + if not hasattr(estimator, "named_steps"): + raise ValueError("step_name can only be used with a pipeline estimator.") + try: + est = est.get_params(deep=True)[step_name] + except KeyError: + raise ValueError(f"Step '{step_name}' is not part of the pipeline.") + else: + while hasattr(est, "steps"): + est = est.steps[-1][1] squeeze_first_dim = False @@ -646,9 +686,14 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non raise ValueError( "inverse_transform can only be applied onto pipeline estimators." ) + if step_name is None: + inverse_funcs = _get_inverse_funcs(estimator) + else: + inverse_funcs = _get_inverse_funcs_before_step(estimator, step_name) + # The inverse_transform parameter will call this method on any # estimator contained in the pipeline, in reverse order. - for inverse_func in _get_inverse_funcs(estimator)[::-1]: + for inverse_func in inverse_funcs[::-1]: logger.debug(f" Applying inverse transformation: {inverse_func}.") coef = inverse_func(coef) @@ -656,6 +701,17 @@ def get_coef(estimator, attr="filters_", inverse_transform=False, *, verbose=Non logger.debug(" Squeezing first dimension of coefficients.") coef = coef[0] + # inverse_transform with Vectorizer returns shape (n_channels, n_components). + # we should transpose to be consistent with how spatial filters + # store filters and patterns: (n_components, n_channels) + if inverse_transform and hasattr(estimator, "steps"): + is_vectorizer = any( + isinstance(param_value, Vectorizer) + for param_value in estimator.get_params(deep=True).values() + ) + if is_vectorizer and coef.ndim == 2: + coef = coef.T + return coef diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index be20b968f07..4071aa00ee0 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -3,22 +3,17 @@ # Copyright the MNE-Python contributors. import collections.abc as abc -import copy as cp from functools import partial import numpy as np from .._fiff.meas_info import Info from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT -from ..evoked import EvokedArray -from ..utils import ( - _check_option, - _validate_type, - fill_doc, -) +from ..utils import _check_option, _validate_type, fill_doc, legacy from ._covs_ged import _csp_estimate, _spoc_estimate from ._mod_ged import _csp_mod, _spoc_mod from .base import _GEDTransformer +from .spatial_filter import get_spatial_filter_from_estimator @fill_doc @@ -316,6 +311,7 @@ def fit_transform(self, X, y=None, **fit_params): # use parent TransformerMixin method but with custom docstring return super().fit_transform(X, y=y, **fit_params) + @legacy(alt="get_spatial_filter_from_estimator(clf, info=info).plot_patterns()") @fill_doc def plot_patterns( self, @@ -402,20 +398,9 @@ def plot_patterns( fig : instance of matplotlib.figure.Figure The figure. """ - if units is None: - units = "AU" - if components is None: - components = np.arange(self.n_components) - - # set sampling frequency to have 1 component per time point - info = cp.deepcopy(info) - with info._unlock(): - info["sfreq"] = 1.0 - # create an evoked - patterns = EvokedArray(self.patterns_.T, info, tmin=0) - # the call plot_topomap - fig = patterns.plot_topomap( - times=components, + spf = get_spatial_filter_from_estimator(self, info=info) + return spf.plot_patterns( + components, ch_type=ch_type, scalings=scalings, sensors=sensors, @@ -437,13 +422,13 @@ def plot_patterns( cbar_fmt=cbar_fmt, units=units, axes=axes, - time_format=name_format, + name_format=name_format, nrows=nrows, ncols=ncols, show=show, ) - return fig + @legacy(alt="get_spatial_filter_from_estimator(clf, info=info).plot_filters()") @fill_doc def plot_filters( self, @@ -530,20 +515,9 @@ def plot_filters( fig : instance of matplotlib.figure.Figure The figure. """ - if units is None: - units = "AU" - if components is None: - components = np.arange(self.n_components) - - # set sampling frequency to have 1 component per time point - info = cp.deepcopy(info) - with info._unlock(): - info["sfreq"] = 1.0 - # create an evoked - filters = EvokedArray(self.filters_.T, info, tmin=0) - # the call plot_topomap - fig = filters.plot_topomap( - times=components, + spf = get_spatial_filter_from_estimator(self, info=info) + return spf.plot_filters( + components, ch_type=ch_type, scalings=scalings, sensors=sensors, @@ -565,12 +539,11 @@ def plot_filters( cbar_fmt=cbar_fmt, units=units, axes=axes, - time_format=name_format, + name_format=name_format, nrows=nrows, ncols=ncols, show=show, ) - return fig def _ajd_pham(X, eps=1e-6, max_iter=15): diff --git a/mne/decoding/spatial_filter.py b/mne/decoding/spatial_filter.py new file mode 100644 index 00000000000..169cca7d005 --- /dev/null +++ b/mne/decoding/spatial_filter.py @@ -0,0 +1,639 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import copy as cp + +import matplotlib.pyplot as plt +import numpy as np + +from ..defaults import _BORDER_DEFAULT, _EXTRAPOLATE_DEFAULT, _INTERPOLATION_DEFAULT +from ..evoked import EvokedArray +from ..utils import _check_option, fill_doc, verbose +from ..viz.utils import plt_show +from .base import LinearModel, _GEDTransformer, get_coef + + +def _plot_model( + model_array, + info, + components=None, + *, + evk_tmin=None, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + name_format=None, + nrows=1, + ncols="auto", + show=True, +): + if components is None: + n_comps = model_array.shape[-2] + components = np.arange(n_comps) + kwargs = dict( + # args set here + times=components, + average=None, + proj=False, + units="AU" if units is None else units, + time_format=name_format, + # args passed from the upstream + ch_type=ch_type, + scalings=scalings, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + nrows=nrows, + ncols=ncols, + show=show, + ) + + # set sampling frequency to have 1 component per time point + + if evk_tmin is None: + info = cp.deepcopy(info) + with info._unlock(): + info["sfreq"] = 1.0 + evk_tmin = 0 + + if model_array.ndim == 3: + n_classes = model_array.shape[0] + figs = list() + for class_idx in range(n_classes): + model_evk = EvokedArray(model_array[class_idx].T, info, tmin=evk_tmin) + fig = model_evk.plot_topomap( + axes=axes[class_idx] if axes else None, **kwargs + ) + figs.append(fig) + return figs + else: + model_evk = EvokedArray(model_array.T, info, tmin=evk_tmin) + fig = model_evk.plot_topomap(axes=axes, **kwargs) + return fig + + +def _plot_scree_per_class(evals, add_cumul_evals, axes): + component_numbers = np.arange(len(evals)) + cumul_evals = np.cumsum(evals) if add_cumul_evals else None + # plot individual eigenvalues + color_line = "cornflowerblue" + axes.set_xlabel("Component Index", fontsize=18) + axes.set_ylabel("Eigenvalue", fontsize=18) + axes.plot( + component_numbers, + evals, + color=color_line, + marker="o", + markersize=8, + ) + axes.tick_params(axis="y", labelsize=16) + axes.tick_params(axis="x", labelsize=16) + + if add_cumul_evals: + # plot cumulative eigenvalues + ax2 = axes.twinx() + ax2.grid(False) + color_line = "firebrick" + ax2.set_ylabel("Cumulative Eigenvalues", fontsize=18) + ax2.plot( + component_numbers, + cumul_evals, + color=color_line, + marker="o", + markersize=6, + ) + ax2.tick_params(axis="y", labelcolor=color_line, labelsize=16) + ax2.set_ylim(0) + + +def _plot_scree( + evals, + title="Scree plot", + add_cumul_evals=True, + axes=None, +): + evals_data = evals if evals.ndim == 2 else [evals] + n_classes = len(evals_data) + axes = [axes] if isinstance(axes, plt.Axes) else axes + if axes is not None and n_classes != len(axes): + raise ValueError(f"Received {len(axes)} axes, but expected {n_classes}") + + orig_axes = axes + figs = list() + for class_idx in range(n_classes): + fig = None + if orig_axes is None: + fig, ax = plt.subplots(figsize=(7, 4), layout="constrained") + else: + ax = axes[class_idx] + _plot_scree_per_class(evals_data[class_idx], add_cumul_evals, ax) + if fig is not None: + fig.suptitle(title, fontsize=22) + figs.append(fig) + + return figs[0] if len(figs) == 1 else figs + + +@verbose +def get_spatial_filter_from_estimator( + estimator, + info, + *, + inverse_transform=False, + step_name=None, + get_coefs=("filters_", "patterns_", "evals_"), + patterns_method=None, + verbose=None, +): + """Instantiate a :class:`mne.decoding.SpatialFilter` object. + + Creates object from the fitted generalized eigendecomposition + transformers or :class:`mne.decoding.LinearModel`. + This object can be used to visualize spatial filters, + patterns, and eigenvalues. + + Parameters + ---------- + estimator : instance of sklearn.base.BaseEstimator + Sklearn-based estimator or meta-estimator from which to initialize + spatial filter. Use ``step_name`` to select relevant transformer + from the pipeline object (works with nested names using ``__`` syntax). + info : instance of mne.Info + The measurement info object for plotting topomaps. + inverse_transform : bool + If True, returns filters and patterns after inverse transforming them with + the transformer steps of the estimator. Defaults to False. + step_name : str | None + Name of the sklearn's pipeline step to get the coefs from. + If inverse_transform is True, the inverse transformations + will be applied using transformers before this step. + If None, the last step will be used. Defaults to None. + get_coefs : tuple + The names of the coefficient attributes to retrieve, can include + ``'filters_'``, ``'patterns_'`` and ``'evals_'``. + If step is GEDTransformer, will use all. + if step is LinearModel will only use ``'filters_'`` and ``'patterns_'``. + Defaults to (``'filters_'``, ``'patterns_'``, ``'evals_'``). + patterns_method : str + The method used to compute the patterns. Can be None, ``'pinv'`` or ``'haufe'``. + It will be set automatically to ``'pinv'`` if step is GEDTransformer, + or to ``'haufe'`` if step is LinearModel. Defaults to None. + %(verbose)s + + Returns + ------- + sp_filter : instance of mne.decoding.SpatialFilter + The spatial filter object. + + See Also + -------- + SpatialFilter, mne.decoding.LinearModel, mne.decoding.CSP, + mne.decoding.SSD, mne.decoding.XdawnTransformer, mne.decoding.SPoC + + Notes + ----- + .. versionadded:: 1.11 + """ + for coef in get_coefs: + if coef not in ("filters_", "patterns_", "evals_"): + raise ValueError( + f"'get_coefs' can only include 'filters_', " + f"'patterns_' and 'evals_', but got {coef}." + ) + if step_name is not None: + model = estimator.get_params()[step_name] + elif hasattr(estimator, "named_steps"): + model = estimator[-1] + else: + model = estimator + if isinstance(model, LinearModel): + patterns_method = "haufe" + get_coefs = ["filters_", "patterns_"] + elif isinstance(model, _GEDTransformer): + patterns_method = "pinv" + get_coefs = ["filters_", "patterns_", "evals_"] + + coefs = { + coef[:-1]: get_coef( + estimator, + coef, + inverse_transform=False if coef == "evals_" else inverse_transform, + step_name=step_name, + verbose=verbose, + ) + for coef in get_coefs + } + + sp_filter = SpatialFilter(info, patterns_method=patterns_method, **coefs) + return sp_filter + + +class SpatialFilter: + r"""Container for spatial filter weights (evecs) and patterns. + + .. warning:: For MNE-Python decoding classes, this container should be + instantiated with `mne.decoding.get_spatial_filter_from_estimator`. + Direct instantiation with external spatial filters is possible + at your own risk. + + This object is obtained either by generalized eigendecomposition (GED) algorithms + such as :class:`mne.decoding.CSP`, :class:`mne.decoding.SPoC`, + :class:`mne.decoding.SSD`, :class:`mne.decoding.XdawnTransformer` or by + :class:`mne.decoding.LinearModel`, wrapping linear models like SVM or Logit. + The object stores the filters that projects sensor data to a reduced component + space, and the corresponding patterns (obtained by pseudoinverse in GED case or + Haufe's trick in case of :class:`mne.decoding.LinearModel`). It can also be directly + initialized using filters from other transformers (e.g. PyRiemann), + but make sure that the dimensions match. + + Parameters + ---------- + info : instance of Info + The measurement info containing channel topography. + filters : ndarray, shape ((n_classes), n_components, n_channels) + The spatial filters (transposed eigenvectors of the decomposition). + evals : ndarray, shape ((n_classes), n_components) | None + The eigenvalues of the decomposition. Defaults to ``None``. + patterns : ndarray, shape ((n_classes), n_components, n_channels) | None + The patterns of the decomposition. If None, they will be computed + from the filters using pseudoinverse. Defaults to ``None``. + patterns_method : str + The method used to compute the patterns. Can be ``'pinv'`` or ``'haufe'``. + If ``patterns`` is None, it will be set to ``'pinv'``. Defaults to ``'pinv'``. + + Attributes + ---------- + info : instance of Info + The measurement info. + filters : ndarray, shape (n_components, n_channels) + The spatial filters (unmixing matrix). Applying these filters to the data + gives the component time series. + patterns : ndarray, shape (n_components, n_channels) + The spatial patterns (mixing matrix/forward model). + These represent the scalp topography of each component. + evals : ndarray, shape (n_components,) + The eigenvalues associated with each component. + patterns_method : str + The method used to compute the patterns from the filters. + + See Also + -------- + get_spatial_filter_from_estimator, mne.decoding.LinearModel, mne.decoding.CSP, + mne.decoding.SSD, mne.decoding.XdawnTransformer, mne.decoding.SPoC + + Notes + ----- + The spatial filters and patterns are stored with shape + ``(n_components, n_channels)``. + + Filters and patterns are related by the following equation: + + .. math:: + \mathbf{A} = \mathbf{W}^{-1} + + where :math:`\mathbf{A}` is the matrix of patterns (the mixing matrix) and + :math:`\mathbf{W}` is the matrix of filters (the unmixing matrix). + + For a detailed discussion on the difference between filters and patterns for GED + see :footcite:`Cohen2022` and for linear models in + general see :footcite:`HaufeEtAl2014`. + + .. versionadded:: 1.11 + + References + ---------- + .. footbibliography:: + """ + + def __init__( + self, + info, + filters, + *, + evals=None, + patterns=None, + patterns_method="pinv", + ): + _check_option( + "patterns_method", + patterns_method, + ("pinv", "haufe"), + ) + self.info = info + self.evals = evals + self.filters = filters + n_comps, n_chs = self.filters.shape[-2:] + if patterns is None: + # XXX Using numpy's pinv here to handle 3D case seamlessly + # Perhaps mne.linalg.pinv can be improved to handle 3D also + # Then it could be changed here to be consistent with + # GEDTransformer + self.patterns = np.linalg.pinv(filters.T) + self.patterns_method = "pinv" + else: + self.patterns = patterns + self.patterns_method = patterns_method + + # In case of multi-target classification in LinearModel + # number of targets can be greater than number of channels. + if patterns_method != "haufe" and n_comps > n_chs: + raise ValueError( + "Number of components can't be greater " + "than number of channels in filters, " + "perhaps the provided matrix is transposed?" + ) + if self.filters.shape != self.patterns.shape: + raise ValueError( + f"Shape mismatch between filters and patterns." + f"Filters are {self.filters.shape}," + f"while patterns are {self.patterns.shape}" + ) + + @fill_doc + def plot_filters( + self, + components=None, + tmin=None, + *, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + name_format="Filter%01d", + nrows=1, + ncols="auto", + show=True, + ): + """Plot topographic maps of model filters. + + Parameters + ---------- + components : float | array of float | 'auto' | None + Indices of filters to plot. If "auto", the number of + ``axes`` determines the amount of filters. + If None, all filters will be plotted. Defaults to None. + tmin : float | None + In case filters are distributed temporally, + this can be used to align them with times + and frequency. Use ``epochs.tmin``, for example. + Defaults to None. + %(ch_type_topomap)s + %(scalings_topomap)s + %(sensors_topomap)s + %(show_names_topomap)s + %(mask_evoked_topomap)s + %(mask_params_topomap)s + %(contours_topomap)s + %(outlines_topomap)s + %(sphere_topomap_auto)s + %(image_interp_topomap)s + %(extrapolate_topomap)s + %(border_topomap)s + %(res_topomap)s + %(size_topomap)s + %(cmap_topomap)s + %(vlim_plot_topomap_psd)s + %(cnorm)s + %(colorbar_topomap)s + %(cbar_fmt_topomap)s + %(units_topomap_evoked)s + %(axes_evoked_plot_topomap)s + name_format : str + String format for topomap values. Defaults to ``'Filter%%01d'``. + %(nrows_ncols_topomap)s + %(show)s + + Returns + ------- + fig : instance of matplotlib.figure.Figure + The figure. + """ + fig = _plot_model( + self.filters, + self.info, + components=components, + evk_tmin=tmin, + ch_type=ch_type, + scalings=scalings, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + name_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) + return fig + + @fill_doc + def plot_patterns( + self, + components=None, + tmin=None, + *, + ch_type=None, + scalings=None, + sensors=True, + show_names=False, + mask=None, + mask_params=None, + contours=6, + outlines="head", + sphere=None, + image_interp=_INTERPOLATION_DEFAULT, + extrapolate=_EXTRAPOLATE_DEFAULT, + border=_BORDER_DEFAULT, + res=64, + size=1, + cmap="RdBu_r", + vlim=(None, None), + cnorm=None, + colorbar=True, + cbar_fmt="%3.1f", + units=None, + axes=None, + name_format="Pattern%01d", + nrows=1, + ncols="auto", + show=True, + ): + """Plot topographic maps of model patterns. + + Parameters + ---------- + components : float | array of float | 'auto' | None + Indices of patterns to plot. If "auto", the number of + ``axes`` determines the amount of patterns. + If None, all patterns will be plotted. Defaults to None. + tmin : float | None + In case patterns are distributed temporally, + this can be used to align them with times + and frequency. Use ``epochs.tmin``, for example. + Defaults to None. + %(ch_type_topomap)s + %(scalings_topomap)s + %(sensors_topomap)s + %(show_names_topomap)s + %(mask_evoked_topomap)s + %(mask_params_topomap)s + %(contours_topomap)s + %(outlines_topomap)s + %(sphere_topomap_auto)s + %(image_interp_topomap)s + %(extrapolate_topomap)s + %(border_topomap)s + %(res_topomap)s + %(size_topomap)s + %(cmap_topomap)s + %(vlim_plot_topomap_psd)s + %(cnorm)s + %(colorbar_topomap)s + %(cbar_fmt_topomap)s + %(units_topomap_evoked)s + %(axes_evoked_plot_topomap)s + name_format : str + String format for topomap values. Defaults to ``'Pattern%%01d'``. + %(nrows_ncols_topomap)s + %(show)s + + Returns + ------- + fig : instance of matplotlib.figure.Figure + The figure. + """ + fig = _plot_model( + self.patterns, + self.info, + components=components, + evk_tmin=tmin, + ch_type=ch_type, + scalings=scalings, + sensors=sensors, + show_names=show_names, + mask=mask, + mask_params=mask_params, + contours=contours, + outlines=outlines, + sphere=sphere, + image_interp=image_interp, + extrapolate=extrapolate, + border=border, + res=res, + size=size, + cmap=cmap, + vlim=vlim, + cnorm=cnorm, + colorbar=colorbar, + cbar_fmt=cbar_fmt, + units=units, + axes=axes, + name_format=name_format, + nrows=nrows, + ncols=ncols, + show=show, + ) + return fig + + @fill_doc + def plot_scree( + self, + title="Scree plot", + add_cumul_evals=False, + axes=None, + show=True, + ): + """Plot scree for GED eigenvalues. + + Parameters + ---------- + title : str + Title for the plot. Defaults to ``'Scree plot'``. + add_cumul_evals : bool + Whether to add second line and y-axis for cumulative eigenvalues. + Defaults to ``True``. + axes : instance of Axes | None + The matplotlib axes to plot to. Defaults to ``None``. + %(show)s + + Returns + ------- + fig : instance of matplotlib.figure.Figure + The figure. + """ + if self.evals is None: + raise AttributeError("Can't plot scree if eigenvalues are not provided.") + + fig = _plot_scree( + self.evals, + title=title, + add_cumul_evals=add_cumul_evals, + axes=axes, + ) + plt_show(show, block=False) + return fig diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index a41b3246ed2..68623876222 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -28,6 +28,7 @@ is_classifier, is_regressor, ) +from sklearn.decomposition import PCA from sklearn.discriminant_analysis import LinearDiscriminantAnalysis from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge from sklearn.model_selection import ( @@ -241,35 +242,133 @@ def transform(self, X): @pytest.mark.parametrize("inverse", (True, False)) @pytest.mark.parametrize( - "Scale, kwargs", + "clf", [ - (Scaler, dict(info=None, scalings="mean")), - (_Noop, dict()), + pytest.param( + make_pipeline( + Scaler(info=None, scalings="mean"), + SlidingEstimator(make_pipeline(LinearModel(Ridge()))), + ), + id="Scaler+SlidingEstimator", + ), + pytest.param( + make_pipeline( + _Noop(), + SlidingEstimator(make_pipeline(LinearModel(Ridge()))), + ), + id="Noop+SlidingEstimator", + ), + pytest.param( + SlidingEstimator(make_pipeline(StandardScaler(), LinearModel(Ridge()))), + id="SlidingEstimator+nested StandardScaler", + ), ], ) -def test_get_coef_inverse_transform(inverse, Scale, kwargs): +def test_get_coef_inverse_transform(inverse, clf): """Test get_coef with and without inverse_transform.""" - lm_regression = LinearModel(Ridge()) X, y, A = _make_data(n_samples=1000, n_features=3, n_targets=1) - # Check with search_light and combination of preprocessing ending with sl: - # slider = SlidingEstimator(make_pipeline(StandardScaler(), lm_regression)) - # XXX : line above should work but does not as only last step is - # used in get_coef ... - slider = SlidingEstimator(make_pipeline(lm_regression)) X = np.transpose([X, -X], [1, 2, 0]) # invert X across 2 time samples - clf = make_pipeline(Scale(**kwargs), slider) clf.fit(X, y) patterns = get_coef(clf, "patterns_", inverse) filters = get_coef(clf, "filters_", inverse) assert_array_equal(filters.shape, patterns.shape, X.shape[1:]) # the two time samples get inverted patterns assert_equal(patterns[0, 0], -patterns[0, 1]) + for t in [0, 1]: - filters_t = get_coef( - clf.named_steps["slidingestimator"].estimators_[t], "filters_", False + if hasattr(clf, "named_steps"): + est_t = clf.named_steps["slidingestimator"].estimators_[t] + filters_t = get_coef(est_t, "filters_", inverse) + if inverse: + filters_t = clf[0].inverse_transform(filters_t.reshape(1, -1))[0] + else: + est_t = clf.estimators_[t] + filters_t = get_coef(est_t, "filters_", inverse) + + assert_equal(filters_t, filters[:, t]) + + +def test_get_coef_inverse_step_name(): + """Test get_coef with inverse_transform=True and a specific step_name.""" + X, y, _ = _make_data(n_samples=100, n_features=5, n_targets=1) + + # Test with a simple pipeline + pipe = make_pipeline(StandardScaler(), PCA(n_components=3), LinearModel(Ridge())) + pipe.fit(X, y) + + coef_inv_actual = get_coef( + pipe, attr="patterns_", inverse_transform=True, step_name="linearmodel" + ) + # Reshape your data using array.reshape(1, -1) if it contains a single sample. + coef_raw = pipe.named_steps["linearmodel"].patterns_.reshape(1, -1) + coef_inv_desired = pipe.named_steps["pca"].inverse_transform(coef_raw) + coef_inv_desired = pipe.named_steps["standardscaler"].inverse_transform( + coef_inv_desired + ) + + assert coef_inv_actual.shape == (X.shape[1],) + # Reshape your data using array.reshape(1, -1) if it contains a single sample. + assert_array_almost_equal(coef_inv_actual.reshape(1, -1), coef_inv_desired) + + with pytest.raises(ValueError, match="inverse_transform"): + _ = get_coef( + pipe[-1], # LinearModel + "filters_", + inverse_transform=True, + ) + with pytest.raises(ValueError, match="step_name"): + _ = get_coef( + SlidingEstimator(pipe), + "filters_", + inverse_transform=True, + step_name="slidingestimator__pipeline__linearmodel", + ) + + # Test with a nested pipeline to check __ parsing + inner_pipe = make_pipeline(PCA(n_components=3), LinearModel(Ridge())) + nested_pipe = make_pipeline(StandardScaler(), inner_pipe) + nested_pipe.fit(X, y) + coef_nested_inv_actual = get_coef( + nested_pipe, + attr="patterns_", + inverse_transform=True, + step_name="pipeline__linearmodel", + ) + linearmodel = nested_pipe.named_steps["pipeline"].named_steps["linearmodel"] + pca = nested_pipe.named_steps["pipeline"].named_steps["pca"] + scaler = nested_pipe.named_steps["standardscaler"] + + coef_nested_raw = linearmodel.patterns_.reshape(1, -1) + coef_nested_inv_desired = pca.inverse_transform(coef_nested_raw) + coef_nested_inv_desired = scaler.inverse_transform(coef_nested_inv_desired) + + assert coef_nested_inv_actual.shape == (X.shape[1],) + assert_array_almost_equal( + coef_nested_inv_actual.reshape(1, -1), coef_nested_inv_desired + ) + + with pytest.raises(ValueError, match="i_do_not_exist"): + get_coef( + pipe, attr="patterns_", inverse_transform=True, step_name="i_do_not_exist" + ) + + class NonInvertibleTransformer(BaseEstimator, TransformerMixin): + def fit(self, X, y=None): + return self + + def transform(self, X): + # In a real scenario, this would modify X + return X + + pipe = make_pipeline(NonInvertibleTransformer(), LinearModel(Ridge())) + pipe.fit(X, y) + with pytest.warns(RuntimeWarning, match="not invertible"): + _ = get_coef( + pipe, + "filters_", + inverse_transform=True, + step_name="linearmodel", ) - if Scale is _Noop: - assert_array_equal(filters_t, filters[:, t]) @pytest.mark.parametrize("n_features", [1, 5]) @@ -311,7 +410,15 @@ def test_get_coef_multiclass(n_features, n_targets): if n_features > 1 and n_targets > 1: assert_allclose(A, lm.patterns_.T, atol=2e-2) coef = get_coef(clf, "patterns_", inverse_transform=True) - lm_patterns_ = lm.patterns_[..., np.newaxis] + + lm_patterns_ = lm.patterns_ + # Expected shape is (n_targets, n_features) + # which is equivalent to (n_components, n_channels) + # in spatial filters + if lm_patterns_.ndim == 1: + lm_patterns_ = lm_patterns_[np.newaxis, :] + else: + lm_patterns_ = lm_patterns_[..., np.newaxis] assert_allclose(lm_patterns_, coef, atol=1e-5) # Check can pass fitting parameters diff --git a/mne/decoding/tests/test_spatial_filter.py b/mne/decoding/tests/test_spatial_filter.py new file mode 100644 index 00000000000..385b73fc053 --- /dev/null +++ b/mne/decoding/tests/test_spatial_filter.py @@ -0,0 +1,190 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +pytest.importorskip("sklearn") + +from sklearn.linear_model import LinearRegression +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler + +from mne import Epochs, create_info, io, pick_types, read_events +from mne.decoding import ( + CSP, + LinearModel, + SpatialFilter, + Vectorizer, + XdawnTransformer, + get_spatial_filter_from_estimator, +) + +data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" +raw_fname = data_dir / "test_raw.fif" +event_name = data_dir / "test-eve.fif" +tmin, tmax = -0.1, 0.2 +event_id = dict(aud_l=1, vis_l=3) +start, stop = 0, 8 + + +def _get_X_y(event_id, return_info=False): + raw = io.read_raw(raw_fname, preload=False) + events = read_events(event_name) + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) + picks = picks[2:12:3] # subselect channels -> disable proj! + raw.add_proj([], remove_existing=True) + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + preload=True, + proj=False, + ) + X = epochs.get_data(copy=False, units=dict(eeg="uV", grad="fT/cm", mag="fT")) + y = epochs.events[:, -1] + if return_info: + return X, y, epochs.info + return X, y + + +def test_spatial_filter_init(): + """Test the initialization of the SpatialFilter class.""" + # Test initialization and factory function + rng = np.random.RandomState(0) + n, n_features = 20, 3 + X = rng.rand(n, n_features) + n_targets = 5 + y = rng.rand(n, n_targets) + clf = LinearModel(LinearRegression()) + clf.fit(X, y) + + # test get_spatial_filter_from_estimator for LinearModel + info = create_info(n_features, 1000.0, "eeg") + sp_filter = get_spatial_filter_from_estimator(clf, info) + assert sp_filter.patterns_method == "haufe" + assert_array_equal(sp_filter.filters, clf.filters_) + assert_array_equal(sp_filter.patterns, clf.patterns_) + assert sp_filter.evals is None + + with pytest.raises(ValueError, match="can only include"): + _ = get_spatial_filter_from_estimator( + clf, info, get_coefs=("foo", "foo", "foo") + ) + + event_id = dict(aud_l=1, vis_l=3) + X, y, info = _get_X_y(event_id, return_info=True) + estimator = make_pipeline(Vectorizer(), StandardScaler(), CSP(n_components=4)) + estimator.fit(X, y) + csp = estimator[-1] + # test get_spatial_filter_from_estimator for GED + sp_filter = get_spatial_filter_from_estimator(estimator, info, step_name="csp") + assert sp_filter.patterns_method == "pinv" + assert_array_equal(sp_filter.filters, csp.filters_) + assert_array_equal(sp_filter.patterns, csp.patterns_) + assert_array_equal(sp_filter.evals, csp.evals_) + assert sp_filter.info is info + + # test without step_name + sp_filter = get_spatial_filter_from_estimator(estimator, info) + assert_array_equal(sp_filter.filters, csp.filters_) + assert_array_equal(sp_filter.patterns, csp.patterns_) + assert_array_equal(sp_filter.evals, csp.evals_) + + # test basic initialization + sp_filter = SpatialFilter( + info, filters=csp.filters_, patterns=csp.patterns_, evals=csp.evals_ + ) + assert_array_equal(sp_filter.filters, csp.filters_) + assert_array_equal(sp_filter.patterns, csp.patterns_) + assert_array_equal(sp_filter.evals, csp.evals_) + assert sp_filter.info is info + + # test automatic pattern calculation via pinv + sp_filter_pinv = SpatialFilter(info, filters=csp.filters_, evals=csp.evals_) + patterns_pinv = np.linalg.pinv(csp.filters_.T) + assert_array_equal(sp_filter_pinv.patterns, patterns_pinv) + assert sp_filter_pinv.patterns_method == "pinv" + + # test shape mismatch error + with pytest.raises(ValueError, match="Shape mismatch"): + SpatialFilter(info, filters=csp.filters_, patterns=csp.patterns_[:-1]) + + # test invalid patterns_method + with pytest.raises(ValueError, match="patterns_method"): + SpatialFilter(info, filters=csp.filters_, patterns_method="foo") + + # test n_components > n_channels error + bad_filters = np.random.randn(31, 30) # 31 components, 30 channels + with pytest.raises(ValueError, match="Number of components can't be greater"): + SpatialFilter(info, filters=bad_filters) + + +def test_spatial_filter_plotting(): + """Test the plotting methods of SpatialFilter.""" + event_id = dict(aud_l=1, vis_l=3) + X, y, info = _get_X_y(event_id, return_info=True) + csp = CSP(n_components=4) + csp.fit(X, y) + + sp_filter = get_spatial_filter_from_estimator(csp, info) + + # test plot_filters + fig_filters = sp_filter.plot_filters(components=[0, 1], show=False) + assert isinstance(fig_filters, plt.Figure) + plt.close("all") + + # test plot_patterns + fig_patterns = sp_filter.plot_patterns(show=False) + assert isinstance(fig_patterns, plt.Figure) + plt.close("all") + + # test plot_scree + fig_scree = sp_filter.plot_scree(show=False, add_cumul_evals=True) + assert isinstance(fig_scree, plt.Figure) + plt.close("all") + _, axes = plt.subplots(figsize=(12, 7), layout="constrained") + fig_scree = sp_filter.plot_scree(axes=axes, show=False) + assert fig_scree == list() + plt.close("all") + + # test plot_scree raises error if evals is None + sp_filter_no_evals = SpatialFilter(info, filters=csp.filters_, evals=None) + with pytest.raises(AttributeError, match="eigenvalues are not provided"): + sp_filter_no_evals.plot_scree() + + # 3D case ('multi' GED decomposition) + n_classes = 2 + event_id = dict(aud_l=1, vis_l=3) + X, y, info = _get_X_y(event_id, return_info=True) + xdawn = XdawnTransformer(n_components=4) + xdawn.fit(X, y) + sp_filter = get_spatial_filter_from_estimator(xdawn, info) + + fig_patterns = sp_filter.plot_patterns(show=False) + assert len(fig_patterns) == n_classes + plt.close("all") + + fig_scree = sp_filter.plot_scree(show=False) + assert len(fig_scree) == n_classes + plt.close("all") + + with pytest.raises(ValueError, match="but expected"): + _, axes = plt.subplots(figsize=(12, 7), layout="constrained") + _ = sp_filter.plot_scree(axes=axes, show=False) + + _, axes = plt.subplots(n_classes, figsize=(12, 7), layout="constrained") + fig_scree = sp_filter.plot_scree(axes=axes, show=False) + assert fig_scree == list() + plt.close("all") diff --git a/tutorials/machine-learning/50_decoding.py b/tutorials/machine-learning/50_decoding.py index c2a56ce0555..30f1204598d 100644 --- a/tutorials/machine-learning/50_decoding.py +++ b/tutorials/machine-learning/50_decoding.py @@ -47,6 +47,7 @@ Vectorizer, cross_val_multiscore, get_coef, + get_spatial_filter_from_estimator, ) data_path = sample.data_path() @@ -285,12 +286,16 @@ # This is also called the mixing matrix. The example :ref:`ex-linear-patterns` # discusses the difference between patterns and filters. # -# These can be plotted with: +# These can be plotted for every spatial filter including CSP, XdawnTransformer, +# SSD and SPoC: -# Fit CSP on full data and plot +# Fit CSP on full data, plot eigenvalues sorted based on mutual information, +# and plot patterns and filters for the three components largest components. csp.fit(X, y) -csp.plot_patterns(epochs.info) -csp.plot_filters(epochs.info, scalings=1e-9) +spf = get_spatial_filter_from_estimator(csp, info=epochs.info) +spf.plot_scree() +spf.plot_patterns(components=[0, 1, 2]) +spf.plot_filters(components=[0, 1, 2], scalings=1e-9) # %% # Decoding over time