@@ -432,17 +432,21 @@ def test_tfr_morlet():
432432def test_dpsswavelet ():
433433 """Test DPSS tapers."""
434434 freqs = np .arange (5 , 25 , 3 )
435- Ws = _make_dpss (
436- 1000 , freqs = freqs , n_cycles = freqs / 2.0 , time_bandwidth = 4.0 , zero_mean = True
435+ Ws , weights = _make_dpss (
436+ 1000 ,
437+ freqs = freqs ,
438+ n_cycles = freqs / 2.0 ,
439+ time_bandwidth = 4.0 ,
440+ zero_mean = True ,
441+ return_weights = True ,
437442 )
438443
439- assert len (Ws ) == 3 # 3 tapers expected
444+ assert np .shape (Ws )[:2 ] == (3 , len (freqs )) # 3 tapers expected
445+ assert np .shape (Ws )[:2 ] == np .shape (weights ) # weights of shape (tapers, freqs)
440446
441447 # Check that zero mean is true
442448 assert np .abs (np .mean (np .real (Ws [0 ][0 ]))) < 1e-5
443449
444- assert len (Ws [0 ]) == len (freqs ) # As many wavelets as asked for
445-
446450
447451@pytest .mark .slowtest
448452def test_tfr_multitaper ():
@@ -664,6 +668,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path):
664668 with tfr .info ._unlock ():
665669 tfr .info ["meas_date" ] = want
666670 assert tfr_loaded == tfr
671+ # test with taper dimension and weights
672+ n_tapers = 3 # anything >= 1 should do
673+ weights = np .ones ((n_tapers , tfr .shape [2 ])) # tapers x freqs
674+ state = tfr .__getstate__ ()
675+ state ["data" ] = np .repeat (np .expand_dims (tfr .data , 2 ), n_tapers , axis = 2 ) # add dim
676+ state ["weights" ] = weights # add weights
677+ state ["dims" ] = ("epoch" , "channel" , "taper" , "freq" , "time" ) # update dims
678+ tfr = EpochsTFR (inst = state )
679+ tfr .save (fname , overwrite = True )
680+ tfr_loaded = read_tfrs (fname )
681+ assert tfr_loaded == tfr
667682 # test overwrite
668683 with pytest .raises (OSError , match = "Destination file exists." ):
669684 tfr .save (fname , overwrite = False )
@@ -722,17 +737,31 @@ def test_average_tfr_init(full_evoked):
722737 AverageTFR (inst = full_evoked , method = "stockwell" , freqs = freqs_linspace )
723738
724739
725- def test_epochstfr_init_errors (epochs_tfr ):
726- """Test __init__ for EpochsTFR."""
727- state = epochs_tfr .__getstate__ ()
728- with pytest .raises (ValueError , match = "EpochsTFR data should be 4D, got 3" ):
729- EpochsTFR (inst = state | dict (data = epochs_tfr .data [..., 0 ]))
740+ @pytest .mark .parametrize ("inst" , ("raw_tfr" , "epochs_tfr" , "average_tfr" ))
741+ def test_tfr_init_errors (inst , request , average_tfr ):
742+ """Test __init__ for {Raw,Epochs,Average}TFR."""
743+ # Load data
744+ inst = _get_inst (inst , request , average_tfr = average_tfr )
745+ state = inst .__getstate__ ()
746+ # Prepare for TFRArray object instantiation
747+ inst_name = inst .__class__ .__name__
748+ class_mapping = dict (RawTFR = RawTFR , EpochsTFR = EpochsTFR , AverageTFR = AverageTFR )
749+ ndims_mapping = dict (
750+ RawTFR = ("3D or 4D" ), EpochsTFR = ("4D or 5D" ), AverageTFR = ("3D or 4D" )
751+ )
752+ TFR = class_mapping [inst_name ]
753+ allowed_ndims = ndims_mapping [inst_name ]
754+ # Check errors caught
755+ with pytest .raises (ValueError , match = f".*TFR data should be { allowed_ndims } " ):
756+ TFR (inst = state | dict (data = inst .data [..., 0 ]))
757+ with pytest .raises (ValueError , match = f".*TFR data should be { allowed_ndims } " ):
758+ TFR (inst = state | dict (data = np .expand_dims (inst .data , axis = (0 , 1 ))))
730759 with pytest .raises (ValueError , match = "Channel axis of data .* doesn't match info" ):
731- EpochsTFR (inst = state | dict (data = epochs_tfr .data [: , :- 1 ]))
760+ TFR (inst = state | dict (data = inst .data [... , :- 1 , :, : ]))
732761 with pytest .raises (ValueError , match = "Time axis of data.*doesn't match times attr" ):
733- EpochsTFR (inst = state | dict (times = epochs_tfr .times [:- 1 ]))
762+ TFR (inst = state | dict (times = inst .times [:- 1 ]))
734763 with pytest .raises (ValueError , match = "Frequency axis of.*doesn't match freqs attr" ):
735- EpochsTFR (inst = state | dict (freqs = epochs_tfr .freqs [:- 1 ]))
764+ TFR (inst = state | dict (freqs = inst .freqs [:- 1 ]))
736765
737766
738767@pytest .mark .parametrize (
@@ -830,6 +859,25 @@ def test_plot():
830859 plt .close ("all" )
831860
832861
862+ @pytest .mark .parametrize ("output" , ("complex" , "phase" ))
863+ def test_plot_multitaper_complex_phase (output ):
864+ """Test TFR plotting of data with a taper dimension."""
865+ # Create example data with a taper dimension
866+ n_chans , n_tapers , n_freqs , n_times = (3 , 4 , 2 , 3 )
867+ data = np .random .rand (n_chans , n_tapers , n_freqs , n_times )
868+ if output == "complex" :
869+ data = data + np .random .rand (* data .shape ) * 1j # add imaginary data
870+ times = np .arange (n_times )
871+ freqs = np .arange (n_freqs )
872+ weights = np .random .rand (n_tapers , n_freqs )
873+ info = mne .create_info (n_chans , 1000.0 , "eeg" )
874+ tfr = AverageTFRArray (
875+ info = info , data = data , times = times , freqs = freqs , weights = weights
876+ )
877+ # Check that plotting works
878+ tfr .plot ()
879+
880+
833881@pytest .mark .parametrize (
834882 "timefreqs,title,combine" ,
835883 (
@@ -1154,6 +1202,15 @@ def test_averaging_epochsTFR():
11541202 ):
11551203 power .average (method = np .mean )
11561204
1205+ # Check it doesn't run for taper spectra
1206+ tapered = epochs .compute_tfr (
1207+ method = "multitaper" , freqs = freqs , n_cycles = n_cycles , output = "complex"
1208+ )
1209+ with pytest .raises (
1210+ NotImplementedError , match = r"Averaging multitaper tapers .* is not supported."
1211+ ):
1212+ tapered .average ()
1213+
11571214
11581215def test_averaging_freqsandtimes_epochsTFR ():
11591216 """Test that EpochsTFR averaging freqs methods work."""
@@ -1258,12 +1315,15 @@ def test_to_data_frame():
12581315 ch_names = ["EEG 001" , "EEG 002" , "EEG 003" , "EEG 004" ]
12591316 n_picks = len (ch_names )
12601317 ch_types = ["eeg" ] * n_picks
1318+ n_tapers = 2
12611319 n_freqs = 5
12621320 n_times = 6
1263- data = np .random .rand (n_epos , n_picks , n_freqs , n_times )
1264- times = np .arange (6 )
1321+ data = np .random .rand (n_epos , n_picks , n_tapers , n_freqs , n_times )
1322+ times = np .arange (n_times )
12651323 srate = 1000.0
1266- freqs = np .arange (5 )
1324+ freqs = np .arange (n_freqs )
1325+ tapers = np .arange (n_tapers )
1326+ weights = np .ones ((n_tapers , n_freqs ))
12671327 events = np .zeros ((n_epos , 3 ), dtype = int )
12681328 events [:, 0 ] = np .arange (n_epos )
12691329 events [:, 2 ] = np .arange (5 , 5 + n_epos )
@@ -1276,6 +1336,7 @@ def test_to_data_frame():
12761336 freqs = freqs ,
12771337 events = events ,
12781338 event_id = event_id ,
1339+ weights = weights ,
12791340 )
12801341 # test index checking
12811342 with pytest .raises (ValueError , match = "options. Valid index options are" ):
@@ -1287,32 +1348,51 @@ def test_to_data_frame():
12871348 # test wide format
12881349 df_wide = tfr .to_data_frame ()
12891350 assert all (np .isin (tfr .ch_names , df_wide .columns ))
1290- assert all (np .isin (["time" , "condition" , "freq" , "epoch" ], df_wide .columns ))
1351+ assert all (
1352+ np .isin (["time" , "condition" , "freq" , "epoch" , "taper" ], df_wide .columns )
1353+ )
12911354 # test long format
12921355 df_long = tfr .to_data_frame (long_format = True )
1293- expected = ("condition" , "epoch" , "freq" , "time" , "channel" , "ch_type" , "value" )
1356+ expected = (
1357+ "condition" ,
1358+ "epoch" ,
1359+ "freq" ,
1360+ "time" ,
1361+ "channel" ,
1362+ "ch_type" ,
1363+ "value" ,
1364+ "taper" ,
1365+ )
12941366 assert set (expected ) == set (df_long .columns )
12951367 assert set (tfr .ch_names ) == set (df_long ["channel" ])
12961368 assert len (df_long ) == tfr .data .size
12971369 # test long format w/ index
12981370 df_long = tfr .to_data_frame (long_format = True , index = ["freq" ])
12991371 del df_wide , df_long
13001372 # test whether data is in correct shape
1301- df = tfr .to_data_frame (index = ["condition" , "epoch" , "freq" , "time" ])
1373+ df = tfr .to_data_frame (index = ["condition" , "epoch" , "taper" , " freq" , "time" ])
13021374 data = tfr .data
13031375 assert_array_equal (df .values [:, 0 ], data [:, 0 , :, :].reshape (1 , - 1 ).squeeze ())
13041376 # compare arbitrary observation:
13051377 assert (
1306- df .loc [("he" , slice (None ), freqs [1 ], times [2 ]), ch_names [3 ]].iat [0 ]
1307- == data [1 , 3 , 1 , 2 ]
1378+ df .loc [("he" , slice (None ), tapers [ 1 ], freqs [1 ], times [2 ]), ch_names [3 ]].iat [0 ]
1379+ == data [1 , 3 , 1 , 1 , 2 ]
13081380 )
13091381
13101382 # Check also for AverageTFR:
1383+ # (remove taper dimension before averaging)
1384+ state = tfr .__getstate__ ()
1385+ state ["data" ] = state ["data" ][:, :, 0 ]
1386+ state ["dims" ] = ("epoch" , "channel" , "freq" , "time" )
1387+ state ["weights" ] = None
1388+ tfr = EpochsTFR (inst = state )
13111389 tfr = tfr .average ()
13121390 with pytest .raises (ValueError , match = "options. Valid index options are" ):
13131391 tfr .to_data_frame (index = ["epoch" , "condition" ])
13141392 with pytest .raises (ValueError , match = '"epoch" is not a valid option' ):
13151393 tfr .to_data_frame (index = "epoch" )
1394+ with pytest .raises (ValueError , match = '"taper" is not a valid option' ):
1395+ tfr .to_data_frame (index = "taper" )
13161396 with pytest .raises (TypeError , match = "index must be `None` or a string " ):
13171397 tfr .to_data_frame (index = np .arange (400 ))
13181398 # test wide format
@@ -1348,11 +1428,13 @@ def test_to_data_frame_index(index):
13481428 ch_names = ["EEG 001" , "EEG 002" , "EEG 003" , "EEG 004" ]
13491429 n_picks = len (ch_names )
13501430 ch_types = ["eeg" ] * n_picks
1431+ n_tapers = 2
13511432 n_freqs = 5
13521433 n_times = 6
1353- data = np .random .rand (n_epos , n_picks , n_freqs , n_times )
1354- times = np .arange (6 )
1355- freqs = np .arange (5 )
1434+ data = np .random .rand (n_epos , n_picks , n_tapers , n_freqs , n_times )
1435+ times = np .arange (n_times )
1436+ freqs = np .arange (n_freqs )
1437+ weights = np .ones ((n_tapers , n_freqs ))
13561438 events = np .zeros ((n_epos , 3 ), dtype = int )
13571439 events [:, 0 ] = np .arange (n_epos )
13581440 events [:, 2 ] = np .arange (5 , 8 )
@@ -1365,14 +1447,15 @@ def test_to_data_frame_index(index):
13651447 freqs = freqs ,
13661448 events = events ,
13671449 event_id = event_id ,
1450+ weights = weights ,
13681451 )
13691452 df = tfr .to_data_frame (picks = [0 , 2 , 3 ], index = index )
13701453 # test index order/hierarchy preservation
13711454 if not isinstance (index , list ):
13721455 index = [index ]
13731456 assert list (df .index .names ) == index
13741457 # test that non-indexed data were present as columns
1375- non_index = list (set (["condition" , "time" , "freq" , "epoch" ]) - set (index ))
1458+ non_index = list (set (["condition" , "time" , "freq" , "taper" , " epoch" ]) - set (index ))
13761459 if len (non_index ):
13771460 assert all (np .isin (non_index , df .columns ))
13781461
@@ -1538,7 +1621,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc):
15381621def test_epochs_compute_tfr_multitaper_complex_phase (epochs , output ):
15391622 """Test Epochs.compute_tfr(output="complex"/"phase")."""
15401623 tfr = epochs .compute_tfr ("multitaper" , freqs_linspace , output = output )
1541- assert len (tfr .shape ) == 5
1624+ assert len (tfr .shape ) == 5 # epoch x channel x taper x freq x time
1625+ assert tfr .weights .shape == tfr .shape [2 :4 ] # check weights and coeffs shapes match
15421626
15431627
15441628@pytest .mark .parametrize ("copy" , (False , True ))
@@ -1550,6 +1634,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy):
15501634 assert avgs [0 ].comment == str (epochs_tfr .events [0 , - 1 ])
15511635
15521636
1637+ @pytest .mark .parametrize ("obj_type" , ("raw" , "epochs" , "evoked" ))
1638+ def test_tfrarray_tapered_spectra (obj_type ):
1639+ """Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra."""
1640+ # Create example data with a taper dimension
1641+ n_epochs , n_chans , n_tapers , n_freqs , n_times = (5 , 3 , 4 , 2 , 6 )
1642+ data_shape = (n_chans , n_tapers , n_freqs , n_times )
1643+ if obj_type == "epochs" :
1644+ data_shape = (n_epochs ,) + data_shape
1645+ data = np .random .rand (* data_shape )
1646+ times = np .arange (n_times )
1647+ freqs = np .arange (n_freqs )
1648+ weights = np .random .rand (n_tapers , n_freqs )
1649+ info = mne .create_info (n_chans , 1000.0 , "eeg" )
1650+ # Prepare for TFRArray object instantiation
1651+ defaults = dict (info = info , data = data , times = times , freqs = freqs )
1652+ class_mapping = dict (raw = RawTFRArray , epochs = EpochsTFRArray , evoked = AverageTFRArray )
1653+ TFRArray = class_mapping [obj_type ]
1654+ # Check TFRArray instantiation runs with good data
1655+ TFRArray (** defaults , weights = weights )
1656+ # Check taper dimension but no weights caught
1657+ with pytest .raises (
1658+ ValueError , match = "Taper dimension in data, but no weights found."
1659+ ):
1660+ TFRArray (** defaults )
1661+ # Check mismatching n_taper in weights caught
1662+ with pytest .raises (
1663+ ValueError , match = r"Taper axis .* doesn't match weights attribute"
1664+ ):
1665+ TFRArray (** defaults , weights = weights [:- 1 ])
1666+ # Check mismatching n_freq in weights caught
1667+ with pytest .raises (
1668+ ValueError , match = r"Frequency axis .* doesn't match weights attribute"
1669+ ):
1670+ TFRArray (** defaults , weights = weights [:, :- 1 ])
1671+
1672+
15531673def test_tfr_proj (epochs ):
15541674 """Test `compute_tfr(proj=True)`."""
15551675 epochs .compute_tfr (method = "morlet" , freqs = freqs_linspace , proj = True )
@@ -1731,3 +1851,52 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request):
17311851 assert re .match (
17321852 rf"Average over \d{{1,3}} { ch_type } channels\." , popup_fig .axes [0 ].get_title ()
17331853 )
1854+
1855+
1856+ @pytest .mark .parametrize ("output" , ("complex" , "phase" ))
1857+ def test_tfr_topo_plotting_multitaper_complex_phase (output , evoked ):
1858+ """Test plot_joint/topo/topomap() for data with a taper dimension."""
1859+ # Compute TFR with taper dimension
1860+ tfr = evoked .compute_tfr (
1861+ method = "multitaper" , freqs = freqs_linspace , n_cycles = 4 , output = output
1862+ )
1863+ # Check that plotting works
1864+ tfr .plot_joint (topomap_args = dict (res = 8 , contours = 0 , sensors = False )) # for speed
1865+ tfr .plot_topo ()
1866+ tfr .plot_topomap ()
1867+
1868+
1869+ def test_combine_tfr_error_catch (average_tfr ):
1870+ """Test combine_tfr() catches errors."""
1871+ # check unrecognised weights string caught
1872+ with pytest .raises (ValueError , match = 'Weights must be .* "nave" or "equal"' ):
1873+ combine_tfr ([average_tfr , average_tfr ], weights = "foo" )
1874+ # check bad weights size caught
1875+ with pytest .raises (ValueError , match = "Weights must be the same size as all_tfr" ):
1876+ combine_tfr ([average_tfr , average_tfr ], weights = [1 , 1 , 1 ])
1877+ # check different channel names caught
1878+ state = average_tfr .__getstate__ ()
1879+ new_info = average_tfr .info .copy ()
1880+ average_tfr_bad = AverageTFR (
1881+ inst = state | dict (info = new_info .rename_channels ({new_info .ch_names [0 ]: "foo" }))
1882+ )
1883+ with pytest .raises (AssertionError , match = ".* do not contain the same channels" ):
1884+ combine_tfr ([average_tfr , average_tfr_bad ])
1885+ # check different times caught
1886+ average_tfr_bad = AverageTFR (inst = state | dict (times = average_tfr .times + 1 ))
1887+ with pytest .raises (
1888+ AssertionError , match = ".* do not contain the same time instants"
1889+ ):
1890+ combine_tfr ([average_tfr , average_tfr_bad ])
1891+ # check taper dim caught
1892+ n_tapers = 3 # anything >= 1 should do
1893+ weights = np .ones ((n_tapers , average_tfr .shape [1 ])) # tapers x freqs
1894+ state ["data" ] = np .repeat (np .expand_dims (average_tfr .data , 1 ), n_tapers , axis = 1 )
1895+ state ["weights" ] = weights
1896+ state ["dims" ] = ("channel" , "taper" , "freq" , "time" )
1897+ average_tfr_taper = AverageTFR (inst = state )
1898+ with pytest .raises (
1899+ NotImplementedError ,
1900+ match = "Aggregating multitaper tapers across TFR datasets is not supported." ,
1901+ ):
1902+ combine_tfr ([average_tfr_taper , average_tfr_taper ])
0 commit comments