-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
ENH: Viz for spatial filters #13332
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
ENH: Viz for spatial filters #13332
Changes from 46 commits
Commits
Show all changes
50 commits
Select commit
Hold shift + click to select a range
0765454
topomap and scree plots
Genuster 5a7f5bd
Add SpatialFilter visualization class
Genuster 614cd1b
Some doc-related fixes
Genuster c3feb72
more imrovements
Genuster 3266d16
nest EvokedArray import
Genuster d720dc2
add some minor tests and fixes
Genuster d2fdc2d
make get_coef to work with an arbitrary named step
Genuster 4ac29d5
clean get_coef's inverse_transform test
Genuster 37cd069
make instantiation of the spatial filter from ged/linear model a stan…
Genuster 7d78df6
fix tests and nesting
Genuster 874af24
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] f3da23a
add test for get_coef with step_name
Genuster 842764e
add some initialziation and plotting tests
Genuster 8f8336b
Merge remote-tracking branch 'upstream/main' into ged-viz
Genuster 06efd97
fix docstrings
Genuster 49dbdaa
add factory func to doc api
Genuster 9f6b626
another docstring fix and sklearn importorskip
Genuster 04c7d01
more get_coef tests
Genuster 16e35b7
FIX: Timeout
larsoner e5d300a
Merge remote-tracking branch 'upstream/main' into pre-commit-ci-updat…
larsoner d355011
more spatial filter viz tests
Genuster 395ac48
replace CSP's plots
Genuster 4ef949c
add changelog entry
Genuster 05450d4
Merge remote-tracking branch 'upstream/pre-commit-ci-update-config' i…
Genuster de0fe5d
fix changelog
Genuster 203e985
fix axes handling
Genuster f3748c7
Merge branch 'main' into ged-viz
Genuster 4839d51
Merge branch 'main' into ged-viz
Genuster 4c9eb44
Merge branch 'main' into ged-viz
larsoner 225dfff
Merge remote-tracking branch 'upstream/main' into ged-viz
Genuster ba792e2
fix code duplication
Genuster b95591a
Update mne/viz/decoding/ged.py
Genuster f41ad60
add see also and version added
Genuster 524fdee
Fix section order and pyrefs
Genuster e4f1875
fix test
Genuster 0dd8e32
move from .viz to .decoding
Genuster f31d480
add warning in the docstring
Genuster 3855bc0
fix comments
Genuster 4896aac
improve Xdawn example
Genuster b6e5251
improve linear model example
Genuster e2e1a14
improve SSD example
Genuster a9eeccb
update a test
Genuster 59630d5
more updates for tutorial and examples
Genuster d6da2ad
fix docstring
Genuster ef00059
complete tests
Genuster 066cf5f
small fix
Genuster 9a2254f
address Eric's suggestions
Genuster 58760da
Update doc/changes/dev/13332.newfeature.rst
Genuster 74deb16
Merge branch 'main' into ged-viz
Genuster 501818c
Merge remote-tracking branch 'upstream/main' into ged-viz
Genuster File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Implement :class:`mne.decoding.SpatialFilter` container class for | ||
visualisation of filters and patterns for :class:`mne.decoding.LinearModel` | ||
and additionally eigenvalues for GED-based transformers such as | ||
:class:`mne.decoding.XdawnTransformer`, :class:`mne.decoding.CSP`, by `Gennadiy Belonosov`_. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
`PhysioNet documentation page <https://physionet.org/content/eegmmidb/1.0.0/>`_. | ||
The dataset is available at PhysioNet :footcite:`GoldbergerEtAl2000`. | ||
""" | ||
|
||
# Authors: Martin Billinger <[email protected]> | ||
# | ||
# License: BSD-3-Clause | ||
|
@@ -30,7 +31,7 @@ | |
from mne import Epochs, pick_types | ||
from mne.channels import make_standard_montage | ||
from mne.datasets import eegbci | ||
from mne.decoding import CSP | ||
from mne.decoding import CSP, get_spatial_filter_from_estimator | ||
from mne.io import concatenate_raws, read_raw_edf | ||
|
||
print(__doc__) | ||
|
@@ -95,10 +96,11 @@ | |
class_balance = max(class_balance, 1.0 - class_balance) | ||
print(f"Classification accuracy: {np.mean(scores)} / Chance level: {class_balance}") | ||
|
||
# plot CSP patterns estimated on full data for visualization | ||
# plot eigenvalues and patterns estimated on full data for visualization | ||
csp.fit_transform(epochs_data, labels) | ||
|
||
csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5) | ||
spf = get_spatial_filter_from_estimator(csp, info=epochs.info) | ||
spf.plot_scree() | ||
spf.plot_patterns(components=np.arange(4)) | ||
|
||
# %% | ||
# Look at performance over time | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
Channels are concatenated and rescaled to create features vectors that will be | ||
fed into a logistic regression. | ||
""" | ||
|
||
# Authors: Alexandre Barachant <[email protected]> | ||
# | ||
# License: BSD-3-Clause | ||
|
@@ -26,10 +27,9 @@ | |
from sklearn.pipeline import make_pipeline | ||
from sklearn.preprocessing import MinMaxScaler | ||
|
||
from mne import Epochs, EvokedArray, create_info, io, pick_types, read_events | ||
from mne import Epochs, io, pick_types, read_events | ||
from mne.datasets import sample | ||
from mne.decoding import Vectorizer | ||
from mne.preprocessing import Xdawn | ||
from mne.decoding import Vectorizer, XdawnTransformer, get_spatial_filter_from_estimator | ||
|
||
print(__doc__) | ||
|
||
|
@@ -71,31 +71,33 @@ | |
|
||
# Create classification pipeline | ||
clf = make_pipeline( | ||
Xdawn(n_components=n_filter), | ||
XdawnTransformer(n_components=n_filter), | ||
Vectorizer(), | ||
MinMaxScaler(), | ||
OneVsRestClassifier(LogisticRegression(penalty="l1", solver="liblinear")), | ||
) | ||
|
||
# Get the labels | ||
labels = epochs.events[:, -1] | ||
# Get the data and labels | ||
# X is of shape (n_epochs, n_channels, n_times) | ||
X = epochs.get_data(copy=False) | ||
y = epochs.events[:, -1] | ||
|
||
# Cross validator | ||
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42) | ||
|
||
# Do cross-validation | ||
preds = np.empty(len(labels)) | ||
for train, test in cv.split(epochs, labels): | ||
clf.fit(epochs[train], labels[train]) | ||
preds[test] = clf.predict(epochs[test]) | ||
preds = np.empty(len(y)) | ||
for train, test in cv.split(epochs, y): | ||
clf.fit(X[train], y[train]) | ||
preds[test] = clf.predict(X[test]) | ||
|
||
# Classification report | ||
target_names = ["aud_l", "aud_r", "vis_l", "vis_r"] | ||
report = classification_report(labels, preds, target_names=target_names) | ||
report = classification_report(y, preds, target_names=target_names) | ||
print(report) | ||
|
||
# Normalized confusion matrix | ||
cm = confusion_matrix(labels, preds) | ||
cm = confusion_matrix(y, preds) | ||
cm_normalized = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis] | ||
|
||
# Plot confusion matrix | ||
|
@@ -109,30 +111,35 @@ | |
ax.set(ylabel="True label", xlabel="Predicted label") | ||
|
||
# %% | ||
# The ``patterns_`` attribute of a fitted Xdawn instance (here from the last | ||
# cross-validation fold) can be used for visualization. | ||
|
||
fig, axes = plt.subplots( | ||
nrows=len(event_id), | ||
ncols=n_filter, | ||
figsize=(n_filter, len(event_id) * 2), | ||
layout="constrained", | ||
# Patterns of a fitted XdawnTransformer instance (here from the last | ||
# cross-validation fold) can be visualized using SpatialFilter container. | ||
|
||
# Instantiate SpatialFilter | ||
spf = get_spatial_filter_from_estimator( | ||
clf, info=epochs.info, step_name="xdawntransformer" | ||
) | ||
|
||
# Let's first examine the scree plot of generalized eigenvalues | ||
# for each class. | ||
spf.plot_scree(title="") | ||
|
||
# We can see that for all four classes ~five largest components | ||
# capture most of the variance, let's plot their patterns. | ||
# Each class will now return its own figure | ||
components_to_plot = np.arange(5) | ||
figs = spf.plot_patterns( | ||
# Indices of patterns to plot, | ||
# we will plot the first three for each class | ||
components=components_to_plot, | ||
show=False, # to set the titles below | ||
) | ||
fitted_xdawn = clf.steps[0][1] | ||
info = create_info(epochs.ch_names, 1, epochs.get_channel_types()) | ||
info.set_montage(epochs.get_montage()) | ||
for ii, cur_class in enumerate(sorted(event_id)): | ||
cur_patterns = fitted_xdawn.patterns_[cur_class] | ||
pattern_evoked = EvokedArray(cur_patterns[:n_filter].T, info, tmin=0) | ||
pattern_evoked.plot_topomap( | ||
times=np.arange(n_filter), | ||
time_format="Component %d" if ii == 0 else "", | ||
colorbar=False, | ||
show_names=False, | ||
axes=axes[ii], | ||
show=False, | ||
) | ||
axes[ii, 0].set(ylabel=cur_class) | ||
|
||
# Set the class titles | ||
event_id_reversed = {v: k for k, v in event_id.items()} | ||
for fig, class_idx in zip(figs, clf[0].classes_): | ||
class_name = event_id_reversed[class_idx] | ||
fig.suptitle(class_name, fontsize=16) | ||
|
||
|
||
# %% | ||
# References | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ | |
Note patterns/filters in MEG data are more similar than EEG data | ||
because the noise is less spatially correlated in MEG than EEG. | ||
""" | ||
|
||
# Authors: Alexandre Gramfort <[email protected]> | ||
# Romain Trachel <[email protected]> | ||
# Jean-Rémi King <[email protected]> | ||
|
@@ -28,11 +29,16 @@ | |
from sklearn.preprocessing import StandardScaler | ||
|
||
import mne | ||
from mne import EvokedArray, io | ||
from mne import io | ||
from mne.datasets import sample | ||
|
||
# import a linear classifier from mne.decoding | ||
from mne.decoding import LinearModel, Vectorizer, get_coef | ||
from mne.decoding import ( | ||
LinearModel, | ||
SpatialFilter, | ||
Vectorizer, | ||
get_spatial_filter_from_estimator, | ||
) | ||
|
||
print(__doc__) | ||
|
||
|
@@ -77,20 +83,38 @@ | |
X = scaler.fit_transform(meg_data) | ||
model.fit(X, labels) | ||
|
||
# Extract and plot spatial filters and spatial patterns | ||
coefs = dict() | ||
for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)): | ||
# We fit the linear model on Z-scored data. To make the filters | ||
# interpretable, we must reverse this normalization step | ||
coef = scaler.inverse_transform([coef])[0] | ||
|
||
# The data was vectorized to fit a single model across all time points and | ||
# all channels. We thus reshape it: | ||
coef = coef.reshape(len(meg_epochs.ch_names), -1) | ||
|
||
# Plot | ||
evoked = EvokedArray(coef, meg_epochs.info, tmin=epochs.tmin) | ||
fig = evoked.plot_topomap() | ||
fig.suptitle(f"MEG {name}") | ||
coefs[name] = coef.reshape(len(meg_epochs.ch_names), -1).T | ||
|
||
# Now we can instantiate the visualization container | ||
spf = SpatialFilter(info=meg_epochs.info, **coefs) | ||
fig = spf.plot_patterns( | ||
# we will automatically select patterns | ||
components="auto", | ||
# as our filters and patterns correspond to actual times | ||
# we can align them | ||
tmin=epochs.tmin, | ||
units="fT", # it's physical - we inversed the scaling | ||
show=False, # to set the title below | ||
name_format=None, # to plot actual times | ||
) | ||
fig.suptitle("MEG patterns") | ||
# Same for filters | ||
fig = spf.plot_filters( | ||
components="auto", | ||
tmin=epochs.tmin, | ||
units="fT", | ||
show=False, | ||
name_format=None, | ||
) | ||
fig.suptitle("MEG filters") | ||
|
||
# %% | ||
# Let's do the same on EEG data using a scikit-learn pipeline | ||
|
@@ -107,15 +131,26 @@ | |
), | ||
) | ||
clf.fit(X, y) | ||
|
||
# Extract and plot patterns and filters | ||
for name in ("patterns_", "filters_"): | ||
# The `inverse_transform` parameter will call this method on any estimator | ||
# contained in the pipeline, in reverse order. | ||
coef = get_coef(clf, name, inverse_transform=True) | ||
evoked = EvokedArray(coef, epochs.info, tmin=epochs.tmin) | ||
fig = evoked.plot_topomap() | ||
fig.suptitle(f"EEG {name[:-1]}") | ||
spf = get_spatial_filter_from_estimator( | ||
clf, info=epochs.info, inverse_transform=True, step_name="linearmodel" | ||
) | ||
fig = spf.plot_patterns( | ||
components="auto", | ||
tmin=epochs.tmin, | ||
units="uV", | ||
show=False, | ||
name_format=None, | ||
) | ||
fig.suptitle("EEG patterns") | ||
# Same for filters | ||
fig = spf.plot_filters( | ||
components="auto", | ||
tmin=epochs.tmin, | ||
units="uV", | ||
show=False, | ||
name_format=None, | ||
) | ||
fig.suptitle("EEG filters") | ||
|
||
# %% | ||
# References | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.