Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api/visualization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Visualization
plot_volume_source_estimates
plot_vector_source_estimates
plot_sparse_source_estimates
plot_stat_cluster
plot_tfr_topomap
plot_topo_image_epochs
plot_topomap
Expand Down
1 change: 1 addition & 0 deletions doc/changes/dev/13366.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add :func:`~mne.viz.plot_stat_cluster` that plots the spatial extent of a cluster on top of a brain by `Shristi Baral`_.
157 changes: 157 additions & 0 deletions mne/viz/_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
)
from ._dipole import _check_concat_dipoles, _plot_dipole_3d, _plot_dipole_mri_outlines
from .evoked_field import EvokedField
from .ui_events import subscribe
from .utils import (
_check_time_unit,
_get_cmap,
Expand Down Expand Up @@ -4301,3 +4302,159 @@ def _get_3d_option(key):
else:
opt = opt.lower() == "true"
return opt


def plot_stat_cluster(cluster, src, brain, time="max-extent", color="magenta", width=1):
"""Plot the spatial extent of a cluster on top of a brain.

Parameters
----------
cluster : tuple
The cluster to plot. A cluster is a tuple of two elements:
an array of time indices
and an array of vertex indices.
src : SourceSpaces
The source space that was used for the inverse computation.
brain : Brain
The brain figure on which to plot the cluster.
time : float | "interactive" | "max-extent"
The time (in seconds) at which to plot the spatial extent of the cluster.
If set to ``"interactive"`` the time will follow the selected time in the brain
figure.
By default, ``"max-extent"``, the time of maximal spatial extent is chosen.
color : str
A maplotlib-style color specification indicating the color to use when plotting
the spatial extent of the cluster.
width : int
The width of the lines used to draw the outlines.

Returns
-------
brain : Brain
The brain figure, now with the cluster plotted on top of it.
"""
# Here due to circular import
from ..label import Label

# args check
if not isinstance(cluster, tuple):
raise TypeError(f"Tuple expected, got {type(cluster)} instead.")
elif len(cluster) != 2:
raise ValueError(
"A cluster is a tuple of two elements, a list time indices "
"and list of vertex indices."
Comment on lines +4344 to +4345
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"A cluster is a tuple of two elements, a list time indices "
"and list of vertex indices."
"A cluster is a tuple of two elements: an array of time indices "
"and an array of vertex indices."

)
else:
cluster_time_idx, cluster_vertex_index = cluster

# A cluster is defined both in space and time. If we want to plot the boundaries of
# the cluster in space, we must choose a specific time for which to show the
# boundaries (as they change over time).
if time == "max-extent":
time_idx, n_vertices = np.unique(cluster_time_idx, return_counts=True)
time_idx = time_idx[np.argmax(n_vertices)]
elif time == "interactive":
time_idx = brain._data["time_idx"]
elif isinstance(time, float):
time_idx = np.searchsorted(brain._times[:-1], time)
else:
raise ValueError(
"Time should be 'max-extent', 'interactive', or floating point"
f" value, got '{time}' instead."
)

# Select only the vertex indices at the chosen time
draw_vertex_index = [
v for v, t in zip(cluster_vertex_index, cluster_time_idx) if t == time_idx
]

# Create the anatomical label containing the vertex indices belonging to the
# cluster. A label cannot span both hemispheres.
# So we must filter the vertices based on their hemisphere.

# The source space object is actually a list of two source spaces, left and right
# hemisphere.
src_lh, src_rh = src

# Split the vertices based on the hemisphere in which they are located.
lh_verts, rh_verts = src_lh["vertno"], src_rh["vertno"]
n_lh_verts = len(lh_verts)
draw_lh_verts = [lh_verts[v] for v in draw_vertex_index if v < n_lh_verts]
draw_rh_verts = [
rh_verts[v - n_lh_verts] for v in draw_vertex_index if v >= n_lh_verts
]

# Vertices in a label must be unique and in increasing order
draw_lh_verts = np.unique(draw_lh_verts)
draw_rh_verts = np.unique(draw_rh_verts)

# We are now ready to create the anatomical label objects
cluster_index = 0
for label in brain.labels["lh"] + brain.labels["rh"]:
if label.name.startswith("cluster-"):
try:
cluster_index = max(cluster_index, int(label.name.split("-", 1)[1]))
except ValueError:
pass
lh_label = Label(draw_lh_verts, hemi="lh", name=f"cluster-{cluster_index}")
rh_label = Label(draw_rh_verts, hemi="rh", name=f"cluster-{cluster_index}")

# Transform vertex indices into proper vertex numbers.
# Not every vertex in the original high-resolution brain mesh is a
# source point in the source estimate. Do draw nice smooth curves, we need to
# interpolate the vertex indices.

# Here, we interpolate the vertices in each label to the full resolution mesh
if len(lh_label) > 0:
lh_label = lh_label.smooth(
smooth=3, subject=brain._subject, subjects_dir=brain._subjects_dir
)
brain.add_label(lh_label, borders=width, color=color)
if len(rh_label) > 0:
rh_label = rh_label.smooth(
smooth=3, subject=brain._subject, subjects_dir=brain._subjects_dir
)
brain.add_label(rh_label, borders=width, color=color)

def on_time_change(event):
time_idx = np.searchsorted(brain._times, event.time)
for hemi in brain._hemis:
mesh = brain._layered_meshes[hemi]
for i, label in enumerate(brain.labels[hemi]):
if label.name == f"cluster-{cluster_index}":
del brain.labels[hemi][i]
mesh.remove_overlay(label.name)

# Select only the vertex indices at the chosen time
draw_vertex_index = [
v for v, t in zip(cluster_vertex_index, cluster_time_idx) if t == time_idx
]
draw_lh_verts = [lh_verts[v] for v in draw_vertex_index if v < n_lh_verts]
draw_rh_verts = [
rh_verts[v - n_lh_verts] for v in draw_vertex_index if v >= n_lh_verts
]

# Vertices in a label must be unique and in increasing order
draw_lh_verts = np.unique(draw_lh_verts)
draw_rh_verts = np.unique(draw_rh_verts)
lh_label = Label(draw_lh_verts, hemi="lh", name=f"cluster-{cluster_index}")
rh_label = Label(draw_rh_verts, hemi="rh", name=f"cluster-{cluster_index}")
if len(lh_label) > 0:
lh_label = lh_label.smooth(
smooth=3,
subject=brain._subject,
subjects_dir=brain._subjects_dir,
verbose=False,
)
brain.add_label(lh_label, borders=width, color=color)
if len(rh_label) > 0:
rh_label = rh_label.smooth(
smooth=3,
subject=brain._subject,
subjects_dir=brain._subjects_dir,
verbose=False,
)
brain.add_label(rh_label, borders=width, color=color)

if time == "interactive":
subscribe(brain, "time_change", on_time_change)
2 changes: 2 additions & 0 deletions mne/viz/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ __all__ = [
"plot_source_estimates",
"plot_source_spectrogram",
"plot_sparse_source_estimates",
"plot_stat_cluster",
"plot_tfr_topomap",
"plot_topo_image_epochs",
"plot_topomap",
Expand All @@ -97,6 +98,7 @@ from ._3d import (
plot_head_positions,
plot_source_estimates,
plot_sparse_source_estimates,
plot_stat_cluster,
plot_vector_source_estimates,
plot_volume_source_estimates,
set_3d_options,
Expand Down
55 changes: 55 additions & 0 deletions mne/viz/tests/test_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
plot_head_positions,
plot_source_estimates,
plot_sparse_source_estimates,
plot_stat_cluster,
snapshot_brain_montage,
)
from mne.viz._3d import _get_map_ticks, _linearize_map, _process_clim
Expand Down Expand Up @@ -1413,3 +1414,57 @@ def test_link_brains(renderer_interactive):
with pytest.raises(TypeError, match="type is Brain"):
link_brains("foo")
link_brains(brain, time=True, camera=True)


@testing.requires_testing_data
def test_plot_stat_cluster(renderer_interactive):
"""Test plotting clusters on brain in static and interactive mode."""
sample_src = read_source_spaces(src_fname)
vertices = [s["vertno"] for s in sample_src]
n_time = 5
n_verts = sum(len(v) for v in vertices)

# simulate stc data
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this test, I don't think it's actually needed to have STC data, we could just do:

brain = Brain("sample", subjects_dir=subjects_dir, surface="white")

stc_data = np.zeros(n_verts * n_time)
stc_size = stc_data.size
stc_data[(np.random.rand(stc_size // 20) * stc_size).astype(int)] = (
np.random.RandomState(0).rand(stc_data.size // 20)
)
stc_data.shape = (n_verts, n_time)
stc = SourceEstimate(stc_data, vertices, 1, 1)

# Simulate a cluster
cluster_time_idx = [1, 1, 2, 3]
cluster_vertex_idx = [0, 1, 2, 3]
cluster = (cluster_time_idx, cluster_vertex_idx)

brain = plot_source_estimates(
stc,
"sample",
background=(1, 1, 0),
subjects_dir=subjects_dir,
colorbar=True,
clim="auto",
)
# Test for incorrect argument in time
with pytest.raises(ValueError):
plot_stat_cluster(cluster, sample_src, brain, "foo")

# test for incorrect shape of cluster
with pytest.raises(TypeError):
plot_stat_cluster(([1]), sample_src, brain)

# test for incorrect data type of cluster
with pytest.raises(TypeError):
plot_stat_cluster([[1, 2, 3], [1, 2, 3]], sample_src, brain)

# All arguments are correct
plot_stat_cluster(cluster, sample_src, brain)

# Check that the proper anatomical label has been constructed.
assert len(brain.labels["lh"]) == 1
assert len(brain.labels["rh"]) == 0
assert brain.labels["lh"][0].name == "cluster-0"

brain.close()
del brain
34 changes: 29 additions & 5 deletions tutorials/stats-source-space/20_cluster_1samp_spatiotemporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from mne.epochs import equalize_epoch_counts
from mne.minimum_norm import apply_inverse, read_inverse_operator
from mne.stats import spatio_temporal_cluster_1samp_test, summarize_clusters_stc
from mne.viz import plot_stat_cluster

# %%
# Set parameters
Expand Down Expand Up @@ -142,19 +143,18 @@
# Read the source space we are morphing to
src = mne.read_source_spaces(src_fname)
fsave_vertices = [s["vertno"] for s in src]
morph_mat = mne.compute_source_morph(
morph = mne.compute_source_morph(
src=inverse_operator["src"],
subject_to="fsaverage",
spacing=fsave_vertices,
subjects_dir=subjects_dir,
).morph_mat

n_vertices_fsave = morph_mat.shape[0]
)
n_vertices_fsave = morph.morph_mat.shape[0]

# We have to change the shape for the dot() to work properly
X = X.reshape(n_vertices_sample, n_times * n_subjects * 2)
print("Morphing data.")
X = morph_mat.dot(X) # morph_mat is a sparse matrix
X = morph.morph_mat.dot(X) # morph_mat is a sparse matrix
X = X.reshape(n_vertices_fsave, n_times, n_subjects, 2)

# %%
Expand Down Expand Up @@ -264,3 +264,27 @@

# We could save this via the following:
# brain.save_image('clusters.png')

# %%
# Alternatively, you may wish to observe the spatial and temporal extent of
# single clusters. The code below demonstrates how to plot the cluster
# boundary on top of an existing source estimate.

difference = morph.apply(condition1 - condition2)
difference_plot = difference.plot(
hemi="both",
views="lateral",
subjects_dir=subjects_dir,
size=(800, 800),
initial_time=0.1,
)

# Plot one cluster at the time of maximal spatial extent of that cluster
plot_stat_cluster(
good_clusters[2], src, difference_plot, time="max-extent", color="magenta", width=1
)

# Plotting the cluster in interactive mode allows scrolling through time
plot_stat_cluster(
good_clusters[2], src, difference_plot, time="interactive", color="magenta", width=1
)
Loading