@@ -739,7 +739,7 @@ def test_average_tfr_init(full_evoked):
739739
740740@pytest .mark .parametrize ("inst" , ("raw_tfr" , "epochs_tfr" , "average_tfr" ))
741741def test_tfr_init_errors (inst , request , average_tfr ):
742- """Test __init__ for Raw/ Epochs/AverageTFR ."""
742+ """Test __init__ for { Raw, Epochs,Average}TFR ."""
743743 # Load data
744744 inst = _get_inst (inst , request , average_tfr = average_tfr )
745745 state = inst .__getstate__ ()
@@ -1587,7 +1587,7 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy):
15871587
15881588@pytest .mark .parametrize ("inst" , ("raw" , "epochs" , "evoked" ))
15891589def test_tfrarray_tapered_spectra (inst , evoked , request ):
1590- """Test Raw/ Epochs/AverageTFRArray instantiation with tapered spectra."""
1590+ """Test { Raw, Epochs,Average}TFRArray instantiation with tapered spectra."""
15911591 # Load data object
15921592 inst = _get_inst (inst , request , evoked = evoked )
15931593 inst .pick ("mag" )
@@ -1802,3 +1802,39 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request):
18021802 assert re .match (
18031803 rf"Average over \d{{1,3}} { ch_type } channels\." , popup_fig .axes [0 ].get_title ()
18041804 )
1805+
1806+
1807+ def test_combine_tfr_error_catch (request , average_tfr ):
1808+ """Test combine_tfr() catches errors."""
1809+ # check unrecognised weights string caught
1810+ with pytest .raises (ValueError , match = 'Weights must be .* "nave" or "equal"' ):
1811+ combine_tfr ([average_tfr , average_tfr ], weights = "foo" )
1812+ # check bad weights size caught
1813+ with pytest .raises (ValueError , match = "Weights must be the same size as all_tfr" ):
1814+ combine_tfr ([average_tfr , average_tfr ], weights = [1 , 1 , 1 ])
1815+ # check different channel names caught
1816+ state = average_tfr .__getstate__ ()
1817+ new_info = average_tfr .info .copy ()
1818+ average_tfr_bad = AverageTFR (
1819+ inst = state | dict (info = new_info .rename_channels ({new_info .ch_names [0 ]: "foo" }))
1820+ )
1821+ with pytest .raises (AssertionError , match = ".* do not contain the same channels" ):
1822+ combine_tfr ([average_tfr , average_tfr_bad ])
1823+ # check different times caught
1824+ average_tfr_bad = AverageTFR (inst = state | dict (times = average_tfr .times + 1 ))
1825+ with pytest .raises (
1826+ AssertionError , match = ".* do not contain the same time instants"
1827+ ):
1828+ combine_tfr ([average_tfr , average_tfr_bad ])
1829+ # check taper dim caught
1830+ n_tapers = 3 # anything >= 1 should do
1831+ weights = np .ones ((n_tapers , average_tfr .shape [1 ])) # tapers x freqs
1832+ state ["data" ] = np .repeat (np .expand_dims (average_tfr .data , 1 ), n_tapers , axis = 1 )
1833+ state ["weights" ] = weights
1834+ state ["dims" ] = ("channel" , "taper" , "freq" , "time" )
1835+ average_tfr_taper = AverageTFR (inst = state )
1836+ with pytest .raises (
1837+ NotImplementedError ,
1838+ match = "Aggregating multitaper tapers across TFR datasets is not supported." ,
1839+ ):
1840+ combine_tfr ([average_tfr_taper , average_tfr_taper ])
0 commit comments