Skip to content

Commit 97fd038

Browse files
authored
FIX: Fix grand_average to support BaseSpectrum objects (#13375)
1 parent 17c4e03 commit 97fd038

File tree

4 files changed

+42
-5
lines changed

4 files changed

+42
-5
lines changed

doc/changes/dev/13375.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug with :func:`mne.grand_average` not working with :class:`mne.time_frequency.Spectrum` objects, by `Thomas Binns`_.

mne/channels/channels.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,11 @@ def equalize_channels(instances, copy=True, verbose=None):
120120
----------
121121
instances : list
122122
A list of MNE-Python objects to equalize the channels for. Objects can
123-
be of type Raw, Epochs, Evoked, AverageTFR, Forward, Covariance,
123+
be of type Raw, Epochs, Evoked, Spectrum, AverageTFR, Forward, Covariance,
124124
CrossSpectralDensity or Info.
125+
126+
.. versionchanged:: 1.11
127+
Added support for :class:`mne.time_frequency.Spectrum` objects.
125128
copy : bool
126129
When dropping and/or re-ordering channels, an object will be copied
127130
when this parameter is set to ``True``. When set to ``False`` (the
@@ -148,21 +151,24 @@ def equalize_channels(instances, copy=True, verbose=None):
148151
from ..forward import Forward
149152
from ..io import BaseRaw
150153
from ..time_frequency import BaseTFR, CrossSpectralDensity
154+
from ..time_frequency.spectrum import BaseSpectrum
151155

152156
# Instances need to have a `ch_names` attribute and a `pick_channels`
153157
# method that supports `ordered=True`.
154158
allowed_types = (
155159
BaseRaw,
156160
BaseEpochs,
157161
Evoked,
162+
BaseSpectrum,
158163
BaseTFR,
159164
Forward,
160165
Covariance,
161166
CrossSpectralDensity,
162167
Info,
163168
)
164169
allowed_types_str = (
165-
"Raw, Epochs, Evoked, TFR, Forward, Covariance, CrossSpectralDensity or Info"
170+
"Raw, Epochs, Evoked, Spectrum, TFR, Forward, Covariance, CrossSpectralDensity "
171+
"or Info"
166172
)
167173
for inst in instances:
168174
_validate_type(

mne/time_frequency/spectrum.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,10 @@ def method(self):
503503
def nave(self):
504504
return self._nave
505505

506+
@nave.setter
507+
def nave(self, nave):
508+
self._nave = nave
509+
506510
@property
507511
def weights(self):
508512
return self._weights

mne/time_frequency/tests/test_spectrum.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@
1010
from matplotlib.colors import same_color
1111
from numpy.testing import assert_allclose, assert_array_equal
1212

13-
from mne import Annotations, BaseEpochs, create_info, make_fixed_length_epochs
13+
from mne import (
14+
Annotations,
15+
BaseEpochs,
16+
create_info,
17+
grand_average,
18+
make_fixed_length_epochs,
19+
)
20+
from mne.channels import equalize_channels
1421
from mne.io import RawArray
1522
from mne.time_frequency import read_spectrum
1623
from mne.time_frequency.multitaper import _psd_from_mt
@@ -200,8 +207,8 @@ def test_combine_spectrum(raw_spectrum, weights):
200207
spectrum1 = raw_spectrum.copy()
201208
spectrum2 = raw_spectrum.copy()
202209
if weights == "nave":
203-
spectrum1._nave = 1
204-
spectrum2._nave = 2
210+
spectrum1.nave = 1
211+
spectrum2.nave = 2
205212
spectrum2._data *= 2
206213
new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights)
207214
assert_allclose(new_spectrum.data, spectrum1.data * (5 / 3))
@@ -243,6 +250,25 @@ def test_combine_spectrum_error_catch(raw_spectrum):
243250
combine_spectrum([raw_spectrum, raw_spectrum2], weights="equal")
244251

245252

253+
def test_grand_average(raw_spectrum):
254+
"""Test `grand_average()` works for instances of `BaseSpectrum`."""
255+
spectrum1 = raw_spectrum.copy()
256+
spectrum2 = raw_spectrum.copy()
257+
spectrum2._data *= 2
258+
new_spectrum = grand_average([spectrum1, spectrum2])
259+
assert_allclose(new_spectrum.data, spectrum1.data * 1.5)
260+
261+
262+
def test_equalize_channels(raw_spectrum):
263+
"""Test equalization of channels for instances of `BaseSpectrum`."""
264+
spect1 = raw_spectrum.copy()
265+
spect2 = spect1.copy().pick(["MEG 0122", "MEG 0111"])
266+
spect1, spect2 = equalize_channels([spect1, spect2])
267+
268+
assert spect1.ch_names == ["MEG 0111", "MEG 0122"]
269+
assert spect2.ch_names == ["MEG 0111", "MEG 0122"]
270+
271+
246272
def test_spectrum_reject_by_annot(raw):
247273
"""Test rejecting by annotation.
248274

0 commit comments

Comments
 (0)