Skip to content

Commit 999d122

Browse files
committed
Disallow aggregating tapers in combine_tfr
1 parent aaef4b7 commit 999d122

File tree

3 files changed

+49
-2
lines changed

3 files changed

+49
-2
lines changed

mne/time_frequency/tests/test_tfr.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def test_average_tfr_init(full_evoked):
739739

740740
@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr"))
741741
def test_tfr_init_errors(inst, request, average_tfr):
742-
"""Test __init__ for Raw/Epochs/AverageTFR."""
742+
"""Test __init__ for {Raw,Epochs,Average}TFR."""
743743
# Load data
744744
inst = _get_inst(inst, request, average_tfr=average_tfr)
745745
state = inst.__getstate__()
@@ -1587,7 +1587,7 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy):
15871587

15881588
@pytest.mark.parametrize("inst", ("raw", "epochs", "evoked"))
15891589
def test_tfrarray_tapered_spectra(inst, evoked, request):
1590-
"""Test Raw/Epochs/AverageTFRArray instantiation with tapered spectra."""
1590+
"""Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra."""
15911591
# Load data object
15921592
inst = _get_inst(inst, request, evoked=evoked)
15931593
inst.pick("mag")
@@ -1802,3 +1802,39 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request):
18021802
assert re.match(
18031803
rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title()
18041804
)
1805+
1806+
1807+
def test_combine_tfr_error_catch(request, average_tfr):
1808+
"""Test combine_tfr() catches errors."""
1809+
# check unrecognised weights string caught
1810+
with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'):
1811+
combine_tfr([average_tfr, average_tfr], weights="foo")
1812+
# check bad weights size caught
1813+
with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"):
1814+
combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1])
1815+
# check different channel names caught
1816+
state = average_tfr.__getstate__()
1817+
new_info = average_tfr.info.copy()
1818+
average_tfr_bad = AverageTFR(
1819+
inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"}))
1820+
)
1821+
with pytest.raises(AssertionError, match=".* do not contain the same channels"):
1822+
combine_tfr([average_tfr, average_tfr_bad])
1823+
# check different times caught
1824+
average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1))
1825+
with pytest.raises(
1826+
AssertionError, match=".* do not contain the same time instants"
1827+
):
1828+
combine_tfr([average_tfr, average_tfr_bad])
1829+
# check taper dim caught
1830+
n_tapers = 3 # anything >= 1 should do
1831+
weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs
1832+
state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1)
1833+
state["weights"] = weights
1834+
state["dims"] = ("channel", "taper", "freq", "time")
1835+
average_tfr_taper = AverageTFR(inst=state)
1836+
with pytest.raises(
1837+
NotImplementedError,
1838+
match="Aggregating multitaper tapers across TFR datasets is not supported.",
1839+
):
1840+
combine_tfr([average_tfr_taper, average_tfr_taper])

mne/time_frequency/tfr.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3941,8 +3941,16 @@ def combine_tfr(all_tfr, weights="nave"):
39413941
39423942
Notes
39433943
-----
3944+
Aggregating multitaper TFR datasets with a taper dimension such as for complex or
3945+
phase data is not supported.
3946+
39443947
.. versionadded:: 0.11.0
39453948
"""
3949+
if any("taper" in tfr._dims for tfr in all_tfr):
3950+
raise NotImplementedError(
3951+
"Aggregating multitaper tapers across TFR datasets is not supported."
3952+
)
3953+
39463954
tfr = all_tfr[0].copy()
39473955
if isinstance(weights, str):
39483956
if weights not in ("nave", "equal"):

mne/utils/numerics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,9 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True):
550550
551551
Notes
552552
-----
553+
Aggregating multitaper TFR datasets with a taper dimension such as for complex or
554+
phase data is not supported.
555+
553556
.. versionadded:: 0.11.0
554557
"""
555558
# check if all elements in the given list are evoked data

0 commit comments

Comments
 (0)