Skip to content

Commit 1d2635f

Browse files
tsbinns, drammock, larsoner
authored
[ENH] Add option to store and return TFR taper weights (#12910)
Co-authored-by: Daniel McCloy <dan@mccloy.info>
Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
1 parent 176f64f commit 1d2635f

File tree

8 files changed

+507
-141
lines changed

8 files changed

+507
-141
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added the ``return_weights`` option to :func:`mne.time_frequency.tfr_array_multitaper` for returning the taper weights (when ``output='complex'`` or ``'phase'``); taper weights are now also stored in :class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_.

mne/time_frequency/multitaper.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,7 @@ def tfr_array_multitaper(
471471
output="complex",
472472
n_jobs=None,
473473
*,
474+
return_weights=False,
474475
verbose=None,
475476
):
476477
"""Compute Time-Frequency Representation (TFR) using DPSS tapers.
@@ -504,6 +505,11 @@ def tfr_array_multitaper(
504505
coherence across trials.
505506
%(n_jobs)s
506507
The parallelization is implemented across channels.
508+
return_weights : bool, default False
509+
If True, return the taper weights. Only applies if ``output='complex'`` or
510+
``'phase'``.
511+
512+
.. versionadded:: 1.10.0
507513
%(verbose)s
508514
509515
Returns
@@ -520,6 +526,9 @@ def tfr_array_multitaper(
520526
If ``output`` is ``'avg_power_itc'``, the real values in ``out``
521527
contain the average power and the imaginary values contain the
522528
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
529+
weights : array of shape (n_tapers, n_freqs)
530+
The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and
531+
``return_weights=True``.
523532
524533
See Also
525534
--------
@@ -550,6 +559,7 @@ def tfr_array_multitaper(
550559
use_fft=use_fft,
551560
decim=decim,
552561
output=output,
562+
return_weights=return_weights,
553563
n_jobs=n_jobs,
554564
verbose=verbose,
555565
)

mne/time_frequency/tests/test_tfr.py

Lines changed: 195 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -432,17 +432,21 @@ def test_tfr_morlet():
432432
def test_dpsswavelet():
433433
"""Test DPSS tapers."""
434434
freqs = np.arange(5, 25, 3)
435-
Ws = _make_dpss(
436-
1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True
435+
Ws, weights = _make_dpss(
436+
1000,
437+
freqs=freqs,
438+
n_cycles=freqs / 2.0,
439+
time_bandwidth=4.0,
440+
zero_mean=True,
441+
return_weights=True,
437442
)
438443

439-
assert len(Ws) == 3 # 3 tapers expected
444+
assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected
445+
assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs)
440446

441447
# Check that zero mean is true
442448
assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5
443449

444-
assert len(Ws[0]) == len(freqs) # As many wavelets as asked for
445-
446450

447451
@pytest.mark.slowtest
448452
def test_tfr_multitaper():
@@ -664,6 +668,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path):
664668
with tfr.info._unlock():
665669
tfr.info["meas_date"] = want
666670
assert tfr_loaded == tfr
671+
# test with taper dimension and weights
672+
n_tapers = 3 # anything >= 1 should do
673+
weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs
674+
state = tfr.__getstate__()
675+
state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim
676+
state["weights"] = weights # add weights
677+
state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims
678+
tfr = EpochsTFR(inst=state)
679+
tfr.save(fname, overwrite=True)
680+
tfr_loaded = read_tfrs(fname)
681+
assert tfr_loaded == tfr
667682
# test overwrite
668683
with pytest.raises(OSError, match="Destination file exists."):
669684
tfr.save(fname, overwrite=False)
@@ -722,17 +737,31 @@ def test_average_tfr_init(full_evoked):
722737
AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace)
723738

724739

725-
def test_epochstfr_init_errors(epochs_tfr):
726-
"""Test __init__ for EpochsTFR."""
727-
state = epochs_tfr.__getstate__()
728-
with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"):
729-
EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0]))
740+
@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr"))
741+
def test_tfr_init_errors(inst, request, average_tfr):
742+
"""Test __init__ for {Raw,Epochs,Average}TFR."""
743+
# Load data
744+
inst = _get_inst(inst, request, average_tfr=average_tfr)
745+
state = inst.__getstate__()
746+
# Prepare for {Raw,Epochs,Average}TFR object instantiation
747+
inst_name = inst.__class__.__name__
748+
class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR)
749+
ndims_mapping = dict(
750+
RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 4D")
751+
)
752+
TFR = class_mapping[inst_name]
753+
allowed_ndims = ndims_mapping[inst_name]
754+
# Check errors caught
755+
with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"):
756+
TFR(inst=state | dict(data=inst.data[..., 0]))
757+
with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"):
758+
TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1))))
730759
with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"):
731-
EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1]))
760+
TFR(inst=state | dict(data=inst.data[..., :-1, :, :]))
732761
with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"):
733-
EpochsTFR(inst=state | dict(times=epochs_tfr.times[:-1]))
762+
TFR(inst=state | dict(times=inst.times[:-1]))
734763
with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"):
735-
EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1]))
764+
TFR(inst=state | dict(freqs=inst.freqs[:-1]))
736765

737766

738767
@pytest.mark.parametrize(
@@ -830,6 +859,25 @@ def test_plot():
830859
plt.close("all")
831860

832861

862+
@pytest.mark.parametrize("output", ("complex", "phase"))
863+
def test_plot_multitaper_complex_phase(output):
864+
"""Test TFR plotting of data with a taper dimension."""
865+
# Create example data with a taper dimension
866+
n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3)
867+
data = np.random.rand(n_chans, n_tapers, n_freqs, n_times)
868+
if output == "complex":
869+
data = data + np.random.rand(*data.shape) * 1j # add imaginary data
870+
times = np.arange(n_times)
871+
freqs = np.arange(n_freqs)
872+
weights = np.random.rand(n_tapers, n_freqs)
873+
info = mne.create_info(n_chans, 1000.0, "eeg")
874+
tfr = AverageTFRArray(
875+
info=info, data=data, times=times, freqs=freqs, weights=weights
876+
)
877+
# Check that plotting works
878+
tfr.plot()
879+
880+
833881
@pytest.mark.parametrize(
834882
"timefreqs,title,combine",
835883
(
@@ -1154,6 +1202,15 @@ def test_averaging_epochsTFR():
11541202
):
11551203
power.average(method=np.mean)
11561204

1205+
# Check that averaging raises an error for data with a taper dimension
1206+
tapered = epochs.compute_tfr(
1207+
method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex"
1208+
)
1209+
with pytest.raises(
1210+
NotImplementedError, match=r"Averaging multitaper tapers .* is not supported."
1211+
):
1212+
tapered.average()
1213+
11571214

11581215
def test_averaging_freqsandtimes_epochsTFR():
11591216
"""Test that EpochsTFR averaging freqs methods work."""
@@ -1258,12 +1315,15 @@ def test_to_data_frame():
12581315
ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"]
12591316
n_picks = len(ch_names)
12601317
ch_types = ["eeg"] * n_picks
1318+
n_tapers = 2
12611319
n_freqs = 5
12621320
n_times = 6
1263-
data = np.random.rand(n_epos, n_picks, n_freqs, n_times)
1264-
times = np.arange(6)
1321+
data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times)
1322+
times = np.arange(n_times)
12651323
srate = 1000.0
1266-
freqs = np.arange(5)
1324+
freqs = np.arange(n_freqs)
1325+
tapers = np.arange(n_tapers)
1326+
weights = np.ones((n_tapers, n_freqs))
12671327
events = np.zeros((n_epos, 3), dtype=int)
12681328
events[:, 0] = np.arange(n_epos)
12691329
events[:, 2] = np.arange(5, 5 + n_epos)
@@ -1276,6 +1336,7 @@ def test_to_data_frame():
12761336
freqs=freqs,
12771337
events=events,
12781338
event_id=event_id,
1339+
weights=weights,
12791340
)
12801341
# test index checking
12811342
with pytest.raises(ValueError, match="options. Valid index options are"):
@@ -1287,32 +1348,51 @@ def test_to_data_frame():
12871348
# test wide format
12881349
df_wide = tfr.to_data_frame()
12891350
assert all(np.isin(tfr.ch_names, df_wide.columns))
1290-
assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns))
1351+
assert all(
1352+
np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns)
1353+
)
12911354
# test long format
12921355
df_long = tfr.to_data_frame(long_format=True)
1293-
expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value")
1356+
expected = (
1357+
"condition",
1358+
"epoch",
1359+
"freq",
1360+
"time",
1361+
"channel",
1362+
"ch_type",
1363+
"value",
1364+
"taper",
1365+
)
12941366
assert set(expected) == set(df_long.columns)
12951367
assert set(tfr.ch_names) == set(df_long["channel"])
12961368
assert len(df_long) == tfr.data.size
12971369
# test long format w/ index
12981370
df_long = tfr.to_data_frame(long_format=True, index=["freq"])
12991371
del df_wide, df_long
13001372
# test whether data is in correct shape
1301-
df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"])
1373+
df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"])
13021374
data = tfr.data
13031375
assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze())
13041376
# compare arbitrary observation:
13051377
assert (
1306-
df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0]
1307-
== data[1, 3, 1, 2]
1378+
df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0]
1379+
== data[1, 3, 1, 1, 2]
13081380
)
13091381

13101382
# Check also for AverageTFR:
1383+
# (remove taper dimension before averaging)
1384+
state = tfr.__getstate__()
1385+
state["data"] = state["data"][:, :, 0]
1386+
state["dims"] = ("epoch", "channel", "freq", "time")
1387+
state["weights"] = None
1388+
tfr = EpochsTFR(inst=state)
13111389
tfr = tfr.average()
13121390
with pytest.raises(ValueError, match="options. Valid index options are"):
13131391
tfr.to_data_frame(index=["epoch", "condition"])
13141392
with pytest.raises(ValueError, match='"epoch" is not a valid option'):
13151393
tfr.to_data_frame(index="epoch")
1394+
with pytest.raises(ValueError, match='"taper" is not a valid option'):
1395+
tfr.to_data_frame(index="taper")
13161396
with pytest.raises(TypeError, match="index must be `None` or a string "):
13171397
tfr.to_data_frame(index=np.arange(400))
13181398
# test wide format
@@ -1348,11 +1428,13 @@ def test_to_data_frame_index(index):
13481428
ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"]
13491429
n_picks = len(ch_names)
13501430
ch_types = ["eeg"] * n_picks
1431+
n_tapers = 2
13511432
n_freqs = 5
13521433
n_times = 6
1353-
data = np.random.rand(n_epos, n_picks, n_freqs, n_times)
1354-
times = np.arange(6)
1355-
freqs = np.arange(5)
1434+
data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times)
1435+
times = np.arange(n_times)
1436+
freqs = np.arange(n_freqs)
1437+
weights = np.ones((n_tapers, n_freqs))
13561438
events = np.zeros((n_epos, 3), dtype=int)
13571439
events[:, 0] = np.arange(n_epos)
13581440
events[:, 2] = np.arange(5, 8)
@@ -1365,14 +1447,15 @@ def test_to_data_frame_index(index):
13651447
freqs=freqs,
13661448
events=events,
13671449
event_id=event_id,
1450+
weights=weights,
13681451
)
13691452
df = tfr.to_data_frame(picks=[0, 2, 3], index=index)
13701453
# test index order/hierarchy preservation
13711454
if not isinstance(index, list):
13721455
index = [index]
13731456
assert list(df.index.names) == index
13741457
# test that non-indexed data were present as columns
1375-
non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index))
1458+
non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index))
13761459
if len(non_index):
13771460
assert all(np.isin(non_index, df.columns))
13781461

@@ -1538,7 +1621,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc):
15381621
def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output):
15391622
"""Test Epochs.compute_tfr(output="complex"/"phase")."""
15401623
tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output)
1541-
assert len(tfr.shape) == 5
1624+
assert len(tfr.shape) == 5 # epoch x channel x taper x freq x time
1625+
assert tfr.weights.shape == tfr.shape[2:4] # check weights and coeffs shapes match
15421626

15431627

15441628
@pytest.mark.parametrize("copy", (False, True))
@@ -1550,6 +1634,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy):
15501634
assert avgs[0].comment == str(epochs_tfr.events[0, -1])
15511635

15521636

1637+
@pytest.mark.parametrize("obj_type", ("raw", "epochs", "evoked"))
1638+
def test_tfrarray_tapered_spectra(obj_type):
1639+
"""Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra."""
1640+
# Create example data with a taper dimension
1641+
n_epochs, n_chans, n_tapers, n_freqs, n_times = (5, 3, 4, 2, 6)
1642+
data_shape = (n_chans, n_tapers, n_freqs, n_times)
1643+
if obj_type == "epochs":
1644+
data_shape = (n_epochs,) + data_shape
1645+
data = np.random.rand(*data_shape)
1646+
times = np.arange(n_times)
1647+
freqs = np.arange(n_freqs)
1648+
weights = np.random.rand(n_tapers, n_freqs)
1649+
info = mne.create_info(n_chans, 1000.0, "eeg")
1650+
# Prepare for TFRArray object instantiation
1651+
defaults = dict(info=info, data=data, times=times, freqs=freqs)
1652+
class_mapping = dict(raw=RawTFRArray, epochs=EpochsTFRArray, evoked=AverageTFRArray)
1653+
TFRArray = class_mapping[obj_type]
1654+
# Check TFRArray instantiation runs with good data
1655+
TFRArray(**defaults, weights=weights)
1656+
# Check taper dimension but no weights caught
1657+
with pytest.raises(
1658+
ValueError, match="Taper dimension in data, but no weights found."
1659+
):
1660+
TFRArray(**defaults)
1661+
# Check mismatching n_taper in weights caught
1662+
with pytest.raises(
1663+
ValueError, match=r"Taper axis .* doesn't match weights attribute"
1664+
):
1665+
TFRArray(**defaults, weights=weights[:-1])
1666+
# Check mismatching n_freq in weights caught
1667+
with pytest.raises(
1668+
ValueError, match=r"Frequency axis .* doesn't match weights attribute"
1669+
):
1670+
TFRArray(**defaults, weights=weights[:, :-1])
1671+
1672+
15531673
def test_tfr_proj(epochs):
15541674
"""Test `compute_tfr(proj=True)`."""
15551675
epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True)
@@ -1731,3 +1851,52 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request):
17311851
assert re.match(
17321852
rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title()
17331853
)
1854+
1855+
1856+
@pytest.mark.parametrize("output", ("complex", "phase"))
1857+
def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked):
1858+
"""Test plot_joint/topo/topomap() for data with a taper dimension."""
1859+
# Compute TFR with taper dimension
1860+
tfr = evoked.compute_tfr(
1861+
method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output
1862+
)
1863+
# Check that plotting works
1864+
tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed
1865+
tfr.plot_topo()
1866+
tfr.plot_topomap()
1867+
1868+
1869+
def test_combine_tfr_error_catch(average_tfr):
1870+
"""Test combine_tfr() catches errors."""
1871+
# check unrecognised weights string caught
1872+
with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'):
1873+
combine_tfr([average_tfr, average_tfr], weights="foo")
1874+
# check bad weights size caught
1875+
with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"):
1876+
combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1])
1877+
# check different channel names caught
1878+
state = average_tfr.__getstate__()
1879+
new_info = average_tfr.info.copy()
1880+
average_tfr_bad = AverageTFR(
1881+
inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"}))
1882+
)
1883+
with pytest.raises(AssertionError, match=".* do not contain the same channels"):
1884+
combine_tfr([average_tfr, average_tfr_bad])
1885+
# check different times caught
1886+
average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1))
1887+
with pytest.raises(
1888+
AssertionError, match=".* do not contain the same time instants"
1889+
):
1890+
combine_tfr([average_tfr, average_tfr_bad])
1891+
# check taper dim caught
1892+
n_tapers = 3 # anything >= 1 should do
1893+
weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs
1894+
state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1)
1895+
state["weights"] = weights
1896+
state["dims"] = ("channel", "taper", "freq", "time")
1897+
average_tfr_taper = AverageTFR(inst=state)
1898+
with pytest.raises(
1899+
NotImplementedError,
1900+
match="Aggregating multitaper tapers across TFR datasets is not supported.",
1901+
):
1902+
combine_tfr([average_tfr_taper, average_tfr_taper])

0 commit comments

Comments
 (0)