Skip to content

Commit b6446a4

Browse files
authored
Merge pull request #26 from wmvanvliet/rois
Add support for ROIs
2 parents e688f64 + 07e5c38 commit b6446a4

File tree

5 files changed

+463
-23
lines changed

5 files changed

+463
-23
lines changed

examples/plot_rsa_roi.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
"""
4+
Source-level RSA using ROI's
5+
============================
6+
7+
In this example, we use anatomical labels as Regions Of Interest (ROIs). Rather
8+
than using a searchlight, we compute DSMs for each ROI and then compute RSA
9+
with a single model DSM.
10+
11+
The dataset will be the MNE-sample dataset: a collection of 288 epochs in which
12+
the participant was presented with an auditory beep or visual stimulus to
13+
either the left or right ear or visual field.
14+
"""
15+
# sphinx_gallery_thumbnail_number=2
16+
# Import required packages
17+
import mne
18+
import mne_rsa
19+
20+
mne.set_log_level(True) # Be less verbose
21+
mne.viz.set_3d_backend('pyvista')
22+
23+
###############################################################################
24+
# We'll be using the data from the MNE-sample set. To speed up computations in
25+
# this example, we're going to use one of the sparse source spaces from the
26+
# testing set.
27+
sample_root = mne.datasets.sample.data_path(verbose=True)
28+
testing_root = mne.datasets.testing.data_path(verbose=True)
29+
sample_path = sample_root / 'MEG' / 'sample'
30+
testing_path = testing_root / 'MEG' / 'sample'
31+
subjects_dir = sample_root / 'subjects'
32+
33+
###############################################################################
34+
# Creating epochs from the continuous (raw) data. We downsample to 100 Hz to
35+
# speed up the RSA computations later on.
36+
raw = mne.io.read_raw_fif(sample_path / 'sample_audvis_filt-0-40_raw.fif')
37+
events = mne.read_events(sample_path / 'sample_audvis_filt-0-40_raw-eve.fif')
38+
event_id = {'audio/left': 1,
39+
'audio/right': 2,
40+
'visual/left': 3,
41+
'visual/right': 4}
42+
epochs = mne.Epochs(raw, events, event_id, preload=True)
43+
epochs.resample(100)
44+
45+
###############################################################################
46+
# It's important that the model DSM and the epochs are in the same order, so
47+
# that each row in the model DSM will correspond to an epoch. The model DSM
48+
# will be easier to interpret visually if the data is ordered such that all
49+
# epochs belonging to the same experimental condition are right next to
50+
# each-other, so patterns jump out. This can be achieved by first splitting the
51+
# epochs by experimental condition and then concatenating them together again.
52+
epoch_splits = [epochs[cl] for cl in ['audio/left', 'audio/right',
53+
'visual/left', 'visual/right']]
54+
epochs = mne.concatenate_epochs(epoch_splits)
55+
56+
###############################################################################
57+
# Now that the epochs are in the proper order, we can create a DSM based on the
58+
# experimental conditions. This type of DSM is referred to as a "sensitivity
59+
# DSM". Let's create a sensitivity DSM that will pick up the left auditory
60+
# response when RSA-ed against the MEG data. Since we want to capture areas
61+
# where left beeps generate a large signal, we specify that left beeps should
62+
# be similar to other left beeps. Since we do not want areas where visual
63+
# stimuli generate a large signal, we specify that beeps must be different from
64+
# visual stimuli. Furthermore, since in areas where visual stimuli generate
65+
# only a small signal, random noise will dominate, we also specify that visual
66+
# stimuli are different from other visual stimuli. Finally left and right
67+
# auditory beeps will be somewhat similar.
68+
69+
70+
def sensitivity_metric(event_id_1, event_id_2):
71+
"""Determine similarity between two epochs, given their event ids."""
72+
if event_id_1 == 1 and event_id_2 == 1:
73+
return 0 # Completely similar
74+
if event_id_1 == 2 and event_id_2 == 2:
75+
return 0.5 # Somewhat similar
76+
elif event_id_1 == 1 and event_id_2 == 2:
77+
return 0.5 # Somewhat similar
78+
elif event_id_1 == 2 and event_id_1 == 1:
79+
return 0.5 # Somewhat similar
80+
else:
81+
return 1 # Not similar at all
82+
83+
84+
model_dsm = mne_rsa.compute_dsm(epochs.events[:, 2], metric=sensitivity_metric)
85+
mne_rsa.plot_dsms(model_dsm, title='Model DSM')
86+
87+
###############################################################################
88+
# This example is going to be on source-level, so let's load the inverse
89+
# operator and apply it to obtain a cortical surface source estimate for each
90+
# epoch. To speed up the computation, we going to load an inverse operator from
91+
# the testing dataset that was created using a sparse source space with not too
92+
# many vertices.
93+
inv = mne.minimum_norm.read_inverse_operator(
94+
f'{testing_path}/sample_audvis_trunc-meg-eeg-oct-4-meg-inv.fif')
95+
epochs_stc = mne.minimum_norm.apply_inverse_epochs(epochs, inv, lambda2=0.1111)
96+
97+
###############################################################################
98+
# ROIs need to be defined as ``mne.Label`` objects. Here, we load the APARC
99+
# parcellation generated by FreeSurfer and treat each parcel as an ROI.
100+
rois = mne.read_labels_from_annot(parc='aparc', subject='sample',
101+
subjects_dir=subjects_dir)
102+
103+
###############################################################################
104+
# Performing the RSA. To save time, we don't use a searchlight over time, just
105+
# over the ROIs. The results are returned not only as a NumPy `ndarray`, but
106+
# also as an `mne.SourceEstimate` object, where each vertex beloning to the
107+
# same ROI has the same value.
108+
rsa_vals, stc = mne_rsa.rsa_stcs_rois(epochs_stc, model_dsm, inv['src'], rois,
109+
temporal_radius=None, n_jobs=1,
110+
verbose=False)
111+
112+
###############################################################################
113+
# To plot the RSA values on a brain, we can use one of MNE-RSA's own
114+
# visualization functions.
115+
brain = mne_rsa.plot_roi_map(rsa_vals, rois, subject='sample',
116+
subjects_dir=subjects_dir)
117+
brain.show_view('lateral', distance=600)

mne_rsa/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
__version__ = '0.8dev'
2-
from .source_level import rsa_stcs, dsm_stcs, rsa_nifti, dsm_nifti
2+
from .source_level import (rsa_stcs, dsm_stcs, rsa_stcs_rois, rsa_nifti,
3+
dsm_nifti)
34
from .sensor_level import rsa_evokeds, rsa_epochs, dsm_evokeds, dsm_epochs
45
from .searchlight import searchlight
56
from .rsa import rsa, rsa_gen, rsa_array
67
from .dsm import compute_dsm, compute_dsm_cv, dsm_array
7-
from .viz import plot_dsms, plot_dsms_topo
8+
from .viz import plot_dsms, plot_dsms_topo, plot_roi_map
89
from .folds import create_folds
910

1011
# This function is useful to have nearby

mne_rsa/searchlight.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Classes and functions having to do with creating searchlights."""
12
import numpy as np
23
from mne.utils import logger
34

@@ -42,13 +43,27 @@ class searchlight:
4243
your distance computations.
4344
4445
Defaults to ``None``.
45-
spatial_radius : floats | None
46-
The spatial radius of the searchlight patch in meters. All source
47-
points within this radius will belong to the searchlight patch. Set to
48-
None to only perform the searchlight over time. When this parameter is
49-
set, the ``dist`` parameter must also be specified. Defaults to
50-
``None``.
51-
temporal_radius : float | None
46+
spatial_radius : float | list of list of int | None
47+
This controls how spatial patches will be created. There are several
48+
ways to do this:
49+
50+
The first way is to specify a spatial radius in meters. In this case,
51+
the ``dist`` parameter must also be specified. This will create a
52+
searchlight where each patch contains all source points within this
53+
radius.
54+
55+
The second way is to specify a list of predefined patches. In this
56+
case, each element of the list should itself be a list of integer
57+
indexes along the spatial dimension of the data array. Each element of
58+
this list will become a separate patch using the data at the specified
59+
indices.
60+
61+
The third way is to set this to ``None``, which will disable the making
62+
of spatial patches and only perform the searchlight over time. This can
63+
be thought of as pooling everything into a single spatial patch.
64+
65+
Defaults to``None``.
66+
temporal_radius : int | None
5267
The temporal radius of the searchlight patch in samples. Set to
5368
``None`` to only perform the searchlight over sensors/source points.
5469
Defaults to ``None``.
@@ -126,21 +141,29 @@ def __init__(self, shape, dist=None, spatial_radius=None,
126141

127142
# Will we be creating spatial searchlight patches?
128143
if self.spatial_radius is not None:
129-
if self.dist is None:
130-
raise ValueError('A spatial radius was requested, but no '
131-
'distance information was specified '
132-
'(=dist parameter).')
133144
if self.series_dim is None:
134145
raise ValueError('Cannot create spatial searchlight patches: '
135146
f'the provided data shape ({shape}) has no '
136147
'spatial dimension.')
137-
if self.sel_series is None:
138-
self.sel_series = np.arange(shape[self.series_dim])
139-
140-
# Compressed Sparse Row format is optimal for our computations
141-
from scipy.sparse import issparse
142-
if issparse(self.dist):
143-
self.dist = self.dist.tocsr()
148+
# If spatial radius is a number, we will be making searchlight
149+
# patches based on distance computations. Alternatively, a list of
150+
# predefined spatial patches may be provided, and we don't need
151+
# `dist`.
152+
if type(self.spatial_radius) in [float, int]:
153+
if self.dist is None:
154+
raise ValueError('A spatial radius was requested, but no '
155+
'distance information was specified '
156+
'(=dist parameter).')
157+
# Compressed Sparse Row format is optimal for our computations
158+
from scipy.sparse import issparse
159+
if issparse(self.dist):
160+
self.dist = self.dist.tocsr()
161+
if self.sel_series is None:
162+
self.sel_series = np.arange(shape[self.series_dim])
163+
else:
164+
# Explicit spatial patches were provided
165+
if self.sel_series is None:
166+
self.sel_series = np.arange(len(self.spatial_radius))
144167

145168
# Will we be creating temporal searchlight patches?
146169
if temporal_radius is not None:
@@ -199,6 +222,7 @@ def __init__(self, shape, dist=None, spatial_radius=None,
199222
self._generator = iter([tuple(self.patch_template)])
200223

201224
def __iter__(self):
225+
"""Get an iterator over the searchlight patches."""
202226
return self
203227

204228
def __next__(self):
@@ -211,7 +235,11 @@ def _iter_spatio_temporal(self):
211235
patch = list(self.patch_template) # Copy the template
212236
for series in self.sel_series:
213237
# Compute all spatial locations in the searchligh path.
214-
spat_ind = _get_in_radius(self.dist, series, self.spatial_radius)
238+
if type(self.spatial_radius) in [float, int]:
239+
spat_ind = _get_in_radius(self.dist, series,
240+
self.spatial_radius)
241+
else:
242+
spat_ind = self.spatial_radius[series]
215243
patch[self.series_dim] = spat_ind
216244
for sample in self.time_centers:
217245
temp_ind = slice(sample - self.temporal_radius,
@@ -224,7 +252,11 @@ def _iter_spatial(self):
224252
logger.info('Creating spatial searchlight patches')
225253
patch = list(self.patch_template) # Copy the template
226254
for series in self.sel_series:
227-
spat_ind = _get_in_radius(self.dist, series, self.spatial_radius)
255+
if type(self.spatial_radius) in [float, int]:
256+
spat_ind = _get_in_radius(self.dist, series,
257+
self.spatial_radius)
258+
else:
259+
spat_ind = self.spatial_radius[series]
228260
patch[self.series_dim] = spat_ind
229261
yield tuple(patch)
230262

@@ -239,7 +271,7 @@ def _iter_temporal(self):
239271

240272
@property
241273
def shape(self):
242-
"""Number of generated patches along multiple dimensions.
274+
"""Get the number of generated patches along multiple dimensions.
243275
244276
This is useful for re-shaping the result obtained after consuming the
245277
this generator.

0 commit comments

Comments
 (0)