Skip to content

Commit 0474b49

Browse files
Genusterpre-commit-ci[bot]larsoner
authored
ENH: Viz for spatial filters (#13332)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson <[email protected]>
1 parent adeb865 commit 0474b49

File tree

15 files changed

+1161
-129
lines changed

15 files changed

+1161
-129
lines changed

doc/_includes/ged.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ This section describes the mathematical formulation and application of
1414
Generalized Eigendecomposition (GED), often used in spatial filtering
1515
and source separation algorithms, such as :class:`mne.decoding.CSP`,
1616
:class:`mne.decoding.SPoC`, :class:`mne.decoding.SSD` and
17-
:class:`mne.preprocessing.Xdawn`.
17+
:class:`mne.decoding.XdawnTransformer`.
1818

1919
The core principle of GED is to find a set of channel weights (spatial filter)
2020
that maximizes the ratio of signal power between two data features.

doc/api/decoding.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Decoding
3030
SPoC
3131
SSD
3232
XdawnTransformer
33+
SpatialFilter
3334

3435
Functions that assist with decoding and model fitting:
3536

@@ -39,3 +40,4 @@ Functions that assist with decoding and model fitting:
3940
compute_ems
4041
cross_val_multiscore
4142
get_coef
43+
get_spatial_filter_from_estimator

doc/changes/dev/13332.newfeature.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Implement :class:`mne.decoding.SpatialFilter` class returned by :func:`mne.decoding.get_spatial_filter_from_estimator` for
2+
visualisation of filters and patterns for :class:`mne.decoding.LinearModel`
3+
and additionally eigenvalues for GED-based transformers such as
4+
:class:`mne.decoding.XdawnTransformer`, :class:`mne.decoding.CSP`, by `Gennadiy Belonosov`_.

examples/decoding/decoding_csp_eeg.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
`PhysioNet documentation page <https://physionet.org/content/eegmmidb/1.0.0/>`_.
1515
The dataset is available at PhysioNet :footcite:`GoldbergerEtAl2000`.
1616
"""
17+
1718
# Authors: Martin Billinger <[email protected]>
1819
#
1920
# License: BSD-3-Clause
@@ -30,7 +31,7 @@
3031
from mne import Epochs, pick_types
3132
from mne.channels import make_standard_montage
3233
from mne.datasets import eegbci
33-
from mne.decoding import CSP
34+
from mne.decoding import CSP, get_spatial_filter_from_estimator
3435
from mne.io import concatenate_raws, read_raw_edf
3536

3637
print(__doc__)
@@ -95,10 +96,11 @@
9596
class_balance = max(class_balance, 1.0 - class_balance)
9697
print(f"Classification accuracy: {np.mean(scores)} / Chance level: {class_balance}")
9798

98-
# plot CSP patterns estimated on full data for visualization
99+
# plot eigenvalues and patterns estimated on full data for visualization
99100
csp.fit_transform(epochs_data, labels)
100-
101-
csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5)
101+
spf = get_spatial_filter_from_estimator(csp, info=epochs.info)
102+
spf.plot_scree()
103+
spf.plot_patterns(components=np.arange(4))
102104

103105
# %%
104106
# Look at performance over time

examples/decoding/decoding_spoc_CMC.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import mne
3333
from mne import Epochs
3434
from mne.datasets.fieldtrip_cmc import data_path
35-
from mne.decoding import SPoC
35+
from mne.decoding import SPoC, get_spatial_filter_from_estimator
3636

3737
# Define parameters
3838
fname = data_path() / "SubjectCMC.ds"
@@ -82,9 +82,18 @@
8282
# Plot the contributions to the detected components (i.e., the forward model)
8383

8484
spoc.fit(X, y)
85-
spoc.plot_patterns(meg_epochs.info)
85+
spf = get_spatial_filter_from_estimator(spoc, info=meg_epochs.info)
86+
spf.plot_scree()
87+
88+
# Plot patterns for the first three components
89+
# with largest absolute generalized eigenvalues,
90+
# as we can see on the scree plot
91+
spf.plot_patterns(components=[0, 1, 2])
92+
8693

8794
##############################################################################
8895
# References
8996
# ----------
9097
# .. footbibliography::
98+
99+
# %%

examples/decoding/decoding_xdawn_eeg.py

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Channels are concatenated and rescaled to create features vectors that will be
1111
fed into a logistic regression.
1212
"""
13+
1314
# Authors: Alexandre Barachant <[email protected]>
1415
#
1516
# License: BSD-3-Clause
@@ -26,10 +27,9 @@
2627
from sklearn.pipeline import make_pipeline
2728
from sklearn.preprocessing import MinMaxScaler
2829

29-
from mne import Epochs, EvokedArray, create_info, io, pick_types, read_events
30+
from mne import Epochs, io, pick_types, read_events
3031
from mne.datasets import sample
31-
from mne.decoding import Vectorizer
32-
from mne.preprocessing import Xdawn
32+
from mne.decoding import Vectorizer, XdawnTransformer, get_spatial_filter_from_estimator
3333

3434
print(__doc__)
3535

@@ -71,31 +71,33 @@
7171

7272
# Create classification pipeline
7373
clf = make_pipeline(
74-
Xdawn(n_components=n_filter),
74+
XdawnTransformer(n_components=n_filter),
7575
Vectorizer(),
7676
MinMaxScaler(),
7777
OneVsRestClassifier(LogisticRegression(penalty="l1", solver="liblinear")),
7878
)
7979

80-
# Get the labels
81-
labels = epochs.events[:, -1]
80+
# Get the data and labels
81+
# X is of shape (n_epochs, n_channels, n_times)
82+
X = epochs.get_data(copy=False)
83+
y = epochs.events[:, -1]
8284

8385
# Cross validator
8486
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
8587

8688
# Do cross-validation
87-
preds = np.empty(len(labels))
88-
for train, test in cv.split(epochs, labels):
89-
clf.fit(epochs[train], labels[train])
90-
preds[test] = clf.predict(epochs[test])
89+
preds = np.empty(len(y))
90+
for train, test in cv.split(epochs, y):
91+
clf.fit(X[train], y[train])
92+
preds[test] = clf.predict(X[test])
9193

9294
# Classification report
9395
target_names = ["aud_l", "aud_r", "vis_l", "vis_r"]
94-
report = classification_report(labels, preds, target_names=target_names)
96+
report = classification_report(y, preds, target_names=target_names)
9597
print(report)
9698

9799
# Normalized confusion matrix
98-
cm = confusion_matrix(labels, preds)
100+
cm = confusion_matrix(y, preds)
99101
cm_normalized = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]
100102

101103
# Plot confusion matrix
@@ -109,30 +111,35 @@
109111
ax.set(ylabel="True label", xlabel="Predicted label")
110112

111113
# %%
112-
# The ``patterns_`` attribute of a fitted Xdawn instance (here from the last
113-
# cross-validation fold) can be used for visualization.
114-
115-
fig, axes = plt.subplots(
116-
nrows=len(event_id),
117-
ncols=n_filter,
118-
figsize=(n_filter, len(event_id) * 2),
119-
layout="constrained",
114+
# Patterns of a fitted XdawnTransformer instance (here from the last
115+
# cross-validation fold) can be visualized using SpatialFilter container.
116+
117+
# Instantiate SpatialFilter
118+
spf = get_spatial_filter_from_estimator(
119+
clf, info=epochs.info, step_name="xdawntransformer"
120+
)
121+
122+
# Let's first examine the scree plot of generalized eigenvalues
123+
# for each class.
124+
spf.plot_scree(title="")
125+
126+
# We can see that for all four classes ~five largest components
127+
# capture most of the variance, let's plot their patterns.
128+
# Each class will now return its own figure
129+
components_to_plot = np.arange(5)
130+
figs = spf.plot_patterns(
131+
# Indices of patterns to plot,
132+
# we will plot the first three for each class
133+
components=components_to_plot,
134+
show=False, # to set the titles below
120135
)
121-
fitted_xdawn = clf.steps[0][1]
122-
info = create_info(epochs.ch_names, 1, epochs.get_channel_types())
123-
info.set_montage(epochs.get_montage())
124-
for ii, cur_class in enumerate(sorted(event_id)):
125-
cur_patterns = fitted_xdawn.patterns_[cur_class]
126-
pattern_evoked = EvokedArray(cur_patterns[:n_filter].T, info, tmin=0)
127-
pattern_evoked.plot_topomap(
128-
times=np.arange(n_filter),
129-
time_format="Component %d" if ii == 0 else "",
130-
colorbar=False,
131-
show_names=False,
132-
axes=axes[ii],
133-
show=False,
134-
)
135-
axes[ii, 0].set(ylabel=cur_class)
136+
137+
# Set the class titles
138+
event_id_reversed = {v: k for k, v in event_id.items()}
139+
for fig, class_idx in zip(figs, clf[0].classes_):
140+
class_name = event_id_reversed[class_idx]
141+
fig.suptitle(class_name, fontsize=16)
142+
136143

137144
# %%
138145
# References

examples/decoding/linear_model_patterns.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Note patterns/filters in MEG data are more similar than EEG data
1515
because the noise is less spatially correlated in MEG than EEG.
1616
"""
17+
1718
# Authors: Alexandre Gramfort <[email protected]>
1819
# Romain Trachel <[email protected]>
1920
# Jean-Rémi King <[email protected]>
@@ -28,11 +29,16 @@
2829
from sklearn.preprocessing import StandardScaler
2930

3031
import mne
31-
from mne import EvokedArray, io
32+
from mne import io
3233
from mne.datasets import sample
3334

3435
# import a linear classifier from mne.decoding
35-
from mne.decoding import LinearModel, Vectorizer, get_coef
36+
from mne.decoding import (
37+
LinearModel,
38+
SpatialFilter,
39+
Vectorizer,
40+
get_spatial_filter_from_estimator,
41+
)
3642

3743
print(__doc__)
3844

@@ -77,20 +83,38 @@
7783
X = scaler.fit_transform(meg_data)
7884
model.fit(X, labels)
7985

80-
# Extract and plot spatial filters and spatial patterns
86+
coefs = dict()
8187
for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)):
8288
# We fit the linear model on Z-scored data. To make the filters
8389
# interpretable, we must reverse this normalization step
8490
coef = scaler.inverse_transform([coef])[0]
8591

8692
# The data was vectorized to fit a single model across all time points and
8793
# all channels. We thus reshape it:
88-
coef = coef.reshape(len(meg_epochs.ch_names), -1)
89-
90-
# Plot
91-
evoked = EvokedArray(coef, meg_epochs.info, tmin=epochs.tmin)
92-
fig = evoked.plot_topomap()
93-
fig.suptitle(f"MEG {name}")
94+
coefs[name] = coef.reshape(len(meg_epochs.ch_names), -1).T
95+
96+
# Now we can instantiate the visualization container
97+
spf = SpatialFilter(info=meg_epochs.info, **coefs)
98+
fig = spf.plot_patterns(
99+
# we will automatically select patterns
100+
components="auto",
101+
# as our filters and patterns correspond to actual times
102+
# we can align them
103+
tmin=epochs.tmin,
104+
units="fT", # it's physical - we inversed the scaling
105+
show=False, # to set the title below
106+
name_format=None, # to plot actual times
107+
)
108+
fig.suptitle("MEG patterns")
109+
# Same for filters
110+
fig = spf.plot_filters(
111+
components="auto",
112+
tmin=epochs.tmin,
113+
units="fT",
114+
show=False,
115+
name_format=None,
116+
)
117+
fig.suptitle("MEG filters")
94118

95119
# %%
96120
# Let's do the same on EEG data using a scikit-learn pipeline
@@ -107,15 +131,26 @@
107131
),
108132
)
109133
clf.fit(X, y)
110-
111-
# Extract and plot patterns and filters
112-
for name in ("patterns_", "filters_"):
113-
# The `inverse_transform` parameter will call this method on any estimator
114-
# contained in the pipeline, in reverse order.
115-
coef = get_coef(clf, name, inverse_transform=True)
116-
evoked = EvokedArray(coef, epochs.info, tmin=epochs.tmin)
117-
fig = evoked.plot_topomap()
118-
fig.suptitle(f"EEG {name[:-1]}")
134+
spf = get_spatial_filter_from_estimator(
135+
clf, info=epochs.info, inverse_transform=True, step_name="linearmodel"
136+
)
137+
fig = spf.plot_patterns(
138+
components="auto",
139+
tmin=epochs.tmin,
140+
units="uV",
141+
show=False,
142+
name_format=None,
143+
)
144+
fig.suptitle("EEG patterns")
145+
# Same for filters
146+
fig = spf.plot_filters(
147+
components="auto",
148+
tmin=epochs.tmin,
149+
units="uV",
150+
show=False,
151+
name_format=None,
152+
)
153+
fig.suptitle("EEG filters")
119154

120155
# %%
121156
# References

examples/decoding/ssd_spatial_filters.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import mne
2727
from mne import Epochs
2828
from mne.datasets.fieldtrip_cmc import data_path
29-
from mne.decoding import SSD
29+
from mne.decoding import SSD, get_spatial_filter_from_estimator
3030

3131
# %%
3232
# Define parameters
@@ -70,8 +70,8 @@
7070
# (W^{-1}) or by multiplying the noise cov with the filters Eq. (22) (C_n W)^t.
7171
# We rely on the inversion approach here.
7272

73-
pattern = mne.EvokedArray(data=ssd.patterns_[:4].T, info=ssd.info)
74-
pattern.plot_topomap(units=dict(mag="A.U."), time_format="")
73+
spf = get_spatial_filter_from_estimator(ssd, info=ssd.info)
74+
spf.plot_patterns(components=list(range(4)))
7575

7676
# The topographies suggest that we picked up a parietal alpha generator.
7777

@@ -150,8 +150,8 @@
150150
ssd_epochs.fit(X=epochs.get_data(copy=False))
151151

152152
# Plot topographies.
153-
pattern_epochs = mne.EvokedArray(data=ssd_epochs.patterns_[:4].T, info=ssd_epochs.info)
154-
pattern_epochs.plot_topomap(units=dict(mag="A.U."), time_format="")
153+
spf = get_spatial_filter_from_estimator(ssd_epochs, info=ssd_epochs.info)
154+
spf.plot_patterns(components=list(range(4)))
155155
# %%
156156
# References
157157
# ----------

mne/decoding/__init__.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ __all__ = [
1111
"SSD",
1212
"Scaler",
1313
"SlidingEstimator",
14+
"SpatialFilter",
1415
"TemporalFilter",
1516
"TimeDelayingRidge",
1617
"TimeFrequency",
@@ -21,6 +22,7 @@ __all__ = [
2122
"compute_ems",
2223
"cross_val_multiscore",
2324
"get_coef",
25+
"get_spatial_filter_from_estimator",
2426
]
2527
from .base import (
2628
BaseEstimator,
@@ -33,6 +35,7 @@ from .csp import CSP, SPoC
3335
from .ems import EMS, compute_ems
3436
from .receptive_field import ReceptiveField
3537
from .search_light import GeneralizingEstimator, SlidingEstimator
38+
from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator
3639
from .ssd import SSD
3740
from .time_delaying_ridge import TimeDelayingRidge
3841
from .time_frequency import TimeFrequency

0 commit comments

Comments
 (0)