Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
0765454
topomap and scree plots
Genuster Jul 14, 2025
5a7f5bd
Add SpatialFilter visualization class
Genuster Jul 15, 2025
614cd1b
Some doc-related fixes
Genuster Jul 15, 2025
c3feb72
more imrovements
Genuster Jul 18, 2025
3266d16
nest EvokedArray import
Genuster Jul 18, 2025
d720dc2
add some minor tests and fixes
Genuster Jul 18, 2025
d2fdc2d
make get_coef to work with an arbitrary named step
Genuster Jul 19, 2025
4ac29d5
clean get_coef's inverse_transform test
Genuster Jul 19, 2025
37cd069
make instantiation of the spatial filter from ged/linear model a stan…
Genuster Jul 19, 2025
7d78df6
fix tests and nesting
Genuster Jul 19, 2025
874af24
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Jul 28, 2025
f3da23a
add test for get_coef with step_name
Genuster Jul 29, 2025
842764e
add some initialziation and plotting tests
Genuster Jul 29, 2025
8f8336b
Merge remote-tracking branch 'upstream/main' into ged-viz
Genuster Jul 29, 2025
06efd97
fix docstrings
Genuster Jul 29, 2025
49dbdaa
add factory func to doc api
Genuster Jul 29, 2025
9f6b626
another docstring fix and sklearn importorskip
Genuster Jul 29, 2025
04c7d01
more get_coef tests
Genuster Jul 29, 2025
16e35b7
FIX: Timeout
larsoner Jul 29, 2025
e5d300a
Merge remote-tracking branch 'upstream/main' into pre-commit-ci-updat…
larsoner Jul 29, 2025
d355011
more spatial filter viz tests
Genuster Jul 29, 2025
395ac48
replace CSP's plots
Genuster Jul 29, 2025
4ef949c
add changelog entry
Genuster Jul 29, 2025
05450d4
Merge remote-tracking branch 'upstream/pre-commit-ci-update-config' i…
Genuster Jul 29, 2025
de0fe5d
fix changelog
Genuster Jul 29, 2025
203e985
fix axes handling
Genuster Jul 29, 2025
f3748c7
Merge branch 'main' into ged-viz
Genuster Jul 30, 2025
4839d51
Merge branch 'main' into ged-viz
Genuster Aug 5, 2025
4c9eb44
Merge branch 'main' into ged-viz
larsoner Aug 18, 2025
225dfff
Merge remote-tracking branch 'upstream/main' into ged-viz
Genuster Aug 18, 2025
ba792e2
fix code duplication
Genuster Aug 18, 2025
b95591a
Update mne/viz/decoding/ged.py
Genuster Aug 18, 2025
f41ad60
add see also and version added
Genuster Aug 18, 2025
524fdee
Fix section order and pyrefs
Genuster Aug 18, 2025
e4f1875
fix test
Genuster Aug 19, 2025
0dd8e32
move from .viz to .decoding
Genuster Aug 19, 2025
f31d480
add warning in the docstring
Genuster Aug 19, 2025
3855bc0
fix comments
Genuster Aug 19, 2025
4896aac
improve Xdawn example
Genuster Aug 19, 2025
b6e5251
improve linear model example
Genuster Aug 19, 2025
e2e1a14
improve SSD example
Genuster Aug 19, 2025
a9eeccb
update a test
Genuster Aug 19, 2025
59630d5
more updates for tutorial and examples
Genuster Aug 19, 2025
d6da2ad
fix docstring
Genuster Aug 19, 2025
ef00059
complete tests
Genuster Aug 19, 2025
066cf5f
small fix
Genuster Aug 20, 2025
9a2254f
address Eric's suggestions
Genuster Aug 21, 2025
58760da
Update doc/changes/dev/13332.newfeature.rst
Genuster Aug 21, 2025
74deb16
Merge branch 'main' into ged-viz
Genuster Aug 21, 2025
501818c
Merge remote-tracking branch 'upstream/main' into ged-viz
Genuster Aug 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/_includes/ged.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ This section describes the mathematical formulation and application of
Generalized Eigendecomposition (GED), often used in spatial filtering
and source separation algorithms, such as :class:`mne.decoding.CSP`,
:class:`mne.decoding.SPoC`, :class:`mne.decoding.SSD` and
:class:`mne.preprocessing.Xdawn`.
:class:`mne.decoding.XdawnTransformer`.

The core principle of GED is to find a set of channel weights (spatial filter)
that maximizes the ratio of signal power between two data features.
Expand Down
2 changes: 2 additions & 0 deletions doc/api/decoding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Decoding
SPoC
SSD
XdawnTransformer
SpatialFilter

Functions that assist with decoding and model fitting:

Expand All @@ -39,3 +40,4 @@ Functions that assist with decoding and model fitting:
compute_ems
cross_val_multiscore
get_coef
get_spatial_filter_from_estimator
4 changes: 4 additions & 0 deletions doc/changes/dev/13332.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Implement :class:`mne.decoding.SpatialFilter` class returned by :func:`mne.decoding.get_spatial_filter_from_estimator` for
visualisation of filters and patterns for :class:`mne.decoding.LinearModel`
and additionally eigenvalues for GED-based transformers such as
:class:`mne.decoding.XdawnTransformer`, :class:`mne.decoding.CSP`, by `Gennadiy Belonosov`_.
10 changes: 6 additions & 4 deletions examples/decoding/decoding_csp_eeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
`PhysioNet documentation page <https://physionet.org/content/eegmmidb/1.0.0/>`_.
The dataset is available at PhysioNet :footcite:`GoldbergerEtAl2000`.
"""

# Authors: Martin Billinger <[email protected]>
#
# License: BSD-3-Clause
Expand All @@ -30,7 +31,7 @@
from mne import Epochs, pick_types
from mne.channels import make_standard_montage
from mne.datasets import eegbci
from mne.decoding import CSP
from mne.decoding import CSP, get_spatial_filter_from_estimator
from mne.io import concatenate_raws, read_raw_edf

print(__doc__)
Expand Down Expand Up @@ -95,10 +96,11 @@
class_balance = max(class_balance, 1.0 - class_balance)
print(f"Classification accuracy: {np.mean(scores)} / Chance level: {class_balance}")

# plot CSP patterns estimated on full data for visualization
# plot eigenvalues and patterns estimated on full data for visualization
csp.fit_transform(epochs_data, labels)

csp.plot_patterns(epochs.info, ch_type="eeg", units="Patterns (AU)", size=1.5)
spf = get_spatial_filter_from_estimator(csp, info=epochs.info)
spf.plot_scree()
spf.plot_patterns(components=np.arange(4))

# %%
# Look at performance over time
Expand Down
13 changes: 11 additions & 2 deletions examples/decoding/decoding_spoc_CMC.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import mne
from mne import Epochs
from mne.datasets.fieldtrip_cmc import data_path
from mne.decoding import SPoC
from mne.decoding import SPoC, get_spatial_filter_from_estimator

# Define parameters
fname = data_path() / "SubjectCMC.ds"
Expand Down Expand Up @@ -82,9 +82,18 @@
# Plot the contributions to the detected components (i.e., the forward model)

spoc.fit(X, y)
spoc.plot_patterns(meg_epochs.info)
spf = get_spatial_filter_from_estimator(spoc, info=meg_epochs.info)
spf.plot_scree()

# Plot patterns for the first three components
# with largest absolute generalized eigenvalues,
# as we can see on the scree plot
spf.plot_patterns(components=[0, 1, 2])


##############################################################################
# References
# ----------
# .. footbibliography::

# %%
77 changes: 42 additions & 35 deletions examples/decoding/decoding_xdawn_eeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Channels are concatenated and rescaled to create features vectors that will be
fed into a logistic regression.
"""

# Authors: Alexandre Barachant <[email protected]>
#
# License: BSD-3-Clause
Expand All @@ -26,10 +27,9 @@
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler

from mne import Epochs, EvokedArray, create_info, io, pick_types, read_events
from mne import Epochs, io, pick_types, read_events
from mne.datasets import sample
from mne.decoding import Vectorizer
from mne.preprocessing import Xdawn
from mne.decoding import Vectorizer, XdawnTransformer, get_spatial_filter_from_estimator

print(__doc__)

Expand Down Expand Up @@ -71,31 +71,33 @@

# Create classification pipeline
clf = make_pipeline(
Xdawn(n_components=n_filter),
XdawnTransformer(n_components=n_filter),
Vectorizer(),
MinMaxScaler(),
OneVsRestClassifier(LogisticRegression(penalty="l1", solver="liblinear")),
)

# Get the labels
labels = epochs.events[:, -1]
# Get the data and labels
# X is of shape (n_epochs, n_channels, n_times)
X = epochs.get_data(copy=False)
y = epochs.events[:, -1]

# Cross validator
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# Do cross-validation
preds = np.empty(len(labels))
for train, test in cv.split(epochs, labels):
clf.fit(epochs[train], labels[train])
preds[test] = clf.predict(epochs[test])
preds = np.empty(len(y))
for train, test in cv.split(epochs, y):
clf.fit(X[train], y[train])
preds[test] = clf.predict(X[test])

# Classification report
target_names = ["aud_l", "aud_r", "vis_l", "vis_r"]
report = classification_report(labels, preds, target_names=target_names)
report = classification_report(y, preds, target_names=target_names)
print(report)

# Normalized confusion matrix
cm = confusion_matrix(labels, preds)
cm = confusion_matrix(y, preds)
cm_normalized = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]

# Plot confusion matrix
Expand All @@ -109,30 +111,35 @@
ax.set(ylabel="True label", xlabel="Predicted label")

# %%
# The ``patterns_`` attribute of a fitted Xdawn instance (here from the last
# cross-validation fold) can be used for visualization.

fig, axes = plt.subplots(
nrows=len(event_id),
ncols=n_filter,
figsize=(n_filter, len(event_id) * 2),
layout="constrained",
# Patterns of a fitted XdawnTransformer instance (here from the last
# cross-validation fold) can be visualized using SpatialFilter container.

# Instantiate SpatialFilter
spf = get_spatial_filter_from_estimator(
clf, info=epochs.info, step_name="xdawntransformer"
)

# Let's first examine the scree plot of generalized eigenvalues
# for each class.
spf.plot_scree(title="")

# We can see that for all four classes ~five largest components
# capture most of the variance, let's plot their patterns.
# Each class will now return its own figure
components_to_plot = np.arange(5)
figs = spf.plot_patterns(
# Indices of patterns to plot,
# we will plot the first three for each class
components=components_to_plot,
show=False, # to set the titles below
)
fitted_xdawn = clf.steps[0][1]
info = create_info(epochs.ch_names, 1, epochs.get_channel_types())
info.set_montage(epochs.get_montage())
for ii, cur_class in enumerate(sorted(event_id)):
cur_patterns = fitted_xdawn.patterns_[cur_class]
pattern_evoked = EvokedArray(cur_patterns[:n_filter].T, info, tmin=0)
pattern_evoked.plot_topomap(
times=np.arange(n_filter),
time_format="Component %d" if ii == 0 else "",
colorbar=False,
show_names=False,
axes=axes[ii],
show=False,
)
axes[ii, 0].set(ylabel=cur_class)

# Set the class titles
event_id_reversed = {v: k for k, v in event_id.items()}
for fig, class_idx in zip(figs, clf[0].classes_):
class_name = event_id_reversed[class_idx]
fig.suptitle(class_name, fontsize=16)


# %%
# References
Expand Down
71 changes: 53 additions & 18 deletions examples/decoding/linear_model_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Note patterns/filters in MEG data are more similar than EEG data
because the noise is less spatially correlated in MEG than EEG.
"""

# Authors: Alexandre Gramfort <[email protected]>
# Romain Trachel <[email protected]>
# Jean-Rémi King <[email protected]>
Expand All @@ -28,11 +29,16 @@
from sklearn.preprocessing import StandardScaler

import mne
from mne import EvokedArray, io
from mne import io
from mne.datasets import sample

# import a linear classifier from mne.decoding
from mne.decoding import LinearModel, Vectorizer, get_coef
from mne.decoding import (
LinearModel,
SpatialFilter,
Vectorizer,
get_spatial_filter_from_estimator,
)

print(__doc__)

Expand Down Expand Up @@ -77,20 +83,38 @@
X = scaler.fit_transform(meg_data)
model.fit(X, labels)

# Extract and plot spatial filters and spatial patterns
coefs = dict()
for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)):
# We fit the linear model on Z-scored data. To make the filters
# interpretable, we must reverse this normalization step
coef = scaler.inverse_transform([coef])[0]

# The data was vectorized to fit a single model across all time points and
# all channels. We thus reshape it:
coef = coef.reshape(len(meg_epochs.ch_names), -1)

# Plot
evoked = EvokedArray(coef, meg_epochs.info, tmin=epochs.tmin)
fig = evoked.plot_topomap()
fig.suptitle(f"MEG {name}")
coefs[name] = coef.reshape(len(meg_epochs.ch_names), -1).T

# Now we can instantiate the visualization container
spf = SpatialFilter(info=meg_epochs.info, **coefs)
fig = spf.plot_patterns(
# we will automatically select patterns
components="auto",
# as our filters and patterns correspond to actual times
# we can align them
tmin=epochs.tmin,
units="fT", # it's physical - we inversed the scaling
show=False, # to set the title below
name_format=None, # to plot actual times
)
fig.suptitle("MEG patterns")
# Same for filters
fig = spf.plot_filters(
components="auto",
tmin=epochs.tmin,
units="fT",
show=False,
name_format=None,
)
fig.suptitle("MEG filters")

# %%
# Let's do the same on EEG data using a scikit-learn pipeline
Expand All @@ -107,15 +131,26 @@
),
)
clf.fit(X, y)

# Extract and plot patterns and filters
for name in ("patterns_", "filters_"):
# The `inverse_transform` parameter will call this method on any estimator
# contained in the pipeline, in reverse order.
coef = get_coef(clf, name, inverse_transform=True)
evoked = EvokedArray(coef, epochs.info, tmin=epochs.tmin)
fig = evoked.plot_topomap()
fig.suptitle(f"EEG {name[:-1]}")
spf = get_spatial_filter_from_estimator(
clf, info=epochs.info, inverse_transform=True, step_name="linearmodel"
)
fig = spf.plot_patterns(
components="auto",
tmin=epochs.tmin,
units="uV",
show=False,
name_format=None,
)
fig.suptitle("EEG patterns")
# Same for filters
fig = spf.plot_filters(
components="auto",
tmin=epochs.tmin,
units="uV",
show=False,
name_format=None,
)
fig.suptitle("EEG filters")

# %%
# References
Expand Down
10 changes: 5 additions & 5 deletions examples/decoding/ssd_spatial_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import mne
from mne import Epochs
from mne.datasets.fieldtrip_cmc import data_path
from mne.decoding import SSD
from mne.decoding import SSD, get_spatial_filter_from_estimator

# %%
# Define parameters
Expand Down Expand Up @@ -70,8 +70,8 @@
# (W^{-1}) or by multiplying the noise cov with the filters Eq. (22) (C_n W)^t.
# We rely on the inversion approach here.

pattern = mne.EvokedArray(data=ssd.patterns_[:4].T, info=ssd.info)
pattern.plot_topomap(units=dict(mag="A.U."), time_format="")
spf = get_spatial_filter_from_estimator(ssd, info=ssd.info)
spf.plot_patterns(components=list(range(4)))

# The topographies suggest that we picked up a parietal alpha generator.

Expand Down Expand Up @@ -150,8 +150,8 @@
ssd_epochs.fit(X=epochs.get_data(copy=False))

# Plot topographies.
pattern_epochs = mne.EvokedArray(data=ssd_epochs.patterns_[:4].T, info=ssd_epochs.info)
pattern_epochs.plot_topomap(units=dict(mag="A.U."), time_format="")
spf = get_spatial_filter_from_estimator(ssd_epochs, info=ssd_epochs.info)
spf.plot_patterns(components=list(range(4)))
# %%
# References
# ----------
Expand Down
3 changes: 3 additions & 0 deletions mne/decoding/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ __all__ = [
"SSD",
"Scaler",
"SlidingEstimator",
"SpatialFilter",
"TemporalFilter",
"TimeDelayingRidge",
"TimeFrequency",
Expand All @@ -21,6 +22,7 @@ __all__ = [
"compute_ems",
"cross_val_multiscore",
"get_coef",
"get_spatial_filter_from_estimator",
]
from .base import (
BaseEstimator,
Expand All @@ -33,6 +35,7 @@ from .csp import CSP, SPoC
from .ems import EMS, compute_ems
from .receptive_field import ReceptiveField
from .search_light import GeneralizingEstimator, SlidingEstimator
from .spatial_filter import SpatialFilter, get_spatial_filter_from_estimator
from .ssd import SSD
from .time_delaying_ridge import TimeDelayingRidge
from .time_frequency import TimeFrequency
Expand Down
Loading
Loading