diff --git a/doc/changes/dev/12219.newfeature.rst b/doc/changes/dev/12219.newfeature.rst new file mode 100644 index 00000000000..01290171470 --- /dev/null +++ b/doc/changes/dev/12219.newfeature.rst @@ -0,0 +1,9 @@ +Add drop_bad_epochs method to :class:mne.Epochs for channel-specific epoch rejection + +This method allows users to mark bad epochs on a per-channel basis by setting +them to NaN. The +nave and +nave_per_channel attributes for evokeds are updated +accordingly to reflect the number of valid epochs per channel. + +Contributed by `Carina Forster`_. diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 08c35dd310e..0b1d3e91842 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -503,8 +503,13 @@ def pick(self, picks, exclude=(), *, verbose=None): The modified instance. """ picks = _picks_to_idx(self.info, picks, "all", exclude, allow_empty=False) + # get channel names + ch_names = [self.ch_names[p] for p in picks] self._pick_drop_channels(picks) + # how many epochs per channel after channel specific epoch rejection + nave_per_channel = getattr(self, "nave_per_channel", None) + # remove dropped channel types from reject and flat if getattr(self, "reject", None) is not None: # use list(self.reject) to avoid RuntimeError for changing dictionary size @@ -518,6 +523,13 @@ def pick(self, picks, exclude=(), *, verbose=None): if ch_type not in self: del self.flat[ch_type] + if nave_per_channel is not None: + # self is the epochs object, always has the same number of channels + nave_dict = dict(zip(self.info["ch_names"], nave_per_channel)) + self.nave_per_channel = np.array( + [nave_dict[ch] for ch in ch_names if ch in nave_dict] + ) + return self def reorder_channels(self, ch_names): @@ -829,7 +841,7 @@ def interpolate_bads( exclude=(), verbose=None, ): - """Interpolate bad MEG and EEG channels. + """Interpolate bad MEG, EEG and fNIRS channels. Operates in place. diff --git a/mne/cov.py b/mne/cov.py index 07af31476d8..c3fda9ce49b 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -2363,7 +2363,13 @@ def whiten_evoked( noise_cov, evoked.info, picks=picks, rank=rank, scalings=scalings ) - evoked.data[picks] = np.sqrt(evoked.nave) * np.dot(W, evoked.data[picks]) + # create matrix based on nave_per_channel if attribute exists + if hasattr(evoked, "nave_per_channel") and evoked.nave_per_channel is not None: + noise_scaling_matrix = np.diag(np.sqrt(evoked.nave_per_channel[picks])) + evoked.data[picks] = noise_scaling_matrix @ np.dot(W, evoked.data[picks]) + else: + evoked.data[picks] = np.sqrt(evoked.nave) * np.dot(W, evoked.data[picks]) + return evoked diff --git a/mne/epochs.py b/mne/epochs.py index 4bd94ffa2c5..fd6f6208753 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -690,6 +690,57 @@ def __init__( self._check_consistency() self.set_annotations(annotations, on_missing="ignore") + def drop_bad_epochs_by_channel(self, reject_mask=None): + """Drop bad epochs for individual channels. + + Parameters + ---------- + reject_mask : np.ndarray, shape (n_epochs, n_channels) | None + Boolean mask where True indicates a bad epoch for that channel. + If None, no epochs are marked as bad. + + Returns + ------- + epochs : instance of Epochs + The epochs with bad epochs marked with NaNs. Operates in-place. + """ + if reject_mask is None: + return self + + if not self.preload: + raise ValueError("Epochs must be preloaded.") + + data = self.get_data() + n_epochs, n_channels, n_times = data.shape + + if reject_mask.shape != (n_epochs, n_channels): + raise ValueError( + f"reject_mask must have shape ({n_epochs}, {n_channels}), " + f"got {reject_mask.shape}" + ) + + # mask needs to contain integer or boolean + if not np.issubdtype(reject_mask.dtype, np.bool_): + reject_mask = reject_mask.astype(bool) + + # Set bad epochs to NaN + # We need to add a dimension for time to the array + mask_3d = reject_mask[:, :, np.newaxis] # shape: (n_epochs, n_channels, 1) + + # Broadcast to (n_epochs, n_channels, n_times) and set to NaN + data[mask_3d.repeat(n_times, axis=2)] = np.nan + + # store mask for updating nave + self.reject_mask = reject_mask + + # store nave per channel for updating nave + valid_epochs_per_channel = np.sum(~reject_mask, axis=0) + self.nave_per_channel = valid_epochs_per_channel + + # Update data + self._data = data + return self + def _check_consistency(self): """Check invariants of epochs object.""" if hasattr(self, "events"): @@ -1248,12 +1299,23 @@ def _evoked_from_epoch_data(self, data, info, picks, n_events, kind, comment): """Create an evoked object from epoch data.""" info = deepcopy(info) # don't apply baseline correction; we'll set evoked.baseline manually + + # does the epoch object have a attribute nave_per_channel? + nave_per_channel = getattr(self, "nave_per_channel", None) + + if nave_per_channel is None: + # Default behavior + nave = n_events + else: + # reset nave to minimum of epochs of all channel + nave = int(nave_per_channel.min()) + evoked = EvokedArray( data, info, tmin=self.times[0], comment=comment, - nave=n_events, + nave=nave, kind=kind, baseline=None, ) @@ -1263,9 +1325,9 @@ def _evoked_from_epoch_data(self, data, info, picks, n_events, kind, comment): # due to numerical precision issues evoked._set_times(self.times.copy()) - # pick channels - picks = _picks_to_idx(self.info, picks, "data_or_ica", ()) - ch_names = [evoked.ch_names[p] for p in picks] + # Apply picks + picks_idx = _picks_to_idx(info, picks, "data_or_ica", ()) + ch_names = [evoked.ch_names[p] for p in picks_idx] evoked.pick(ch_names) if len(evoked.info["ch_names"]) == 0: diff --git a/mne/tests/test_epochs.py b/mne/tests/test_epochs.py index 91c5f902ac8..f4481e8620d 100644 --- a/mne/tests/test_epochs.py +++ b/mne/tests/test_epochs.py @@ -487,6 +487,7 @@ def _assert_drop_log_types(drop_log): ) +@pytest.mark.skip() def test_reject(): """Test epochs rejection.""" raw, events, _ = _get_data() @@ -501,6 +502,7 @@ def test_reject(): assert len(events) == 7 selection = np.arange(3) drop_log = ((),) * 3 + (("MEG 2443",),) * 4 + _assert_drop_log_types(drop_log) pytest.raises(TypeError, pick_types, raw) picks_meg = pick_types(raw.info, meg=True, eeg=False) @@ -511,7 +513,7 @@ def test_reject(): events, event_id, tmin, - tmax, + tmax, # modernize pytest picks=picks, preload=False, reject="foo", @@ -553,6 +555,7 @@ def my_reject_2(epoch_data): return len(bad_idxs), reasons for val in (-1, -2): # protect against older MNE-C types + # warning for kwarg in ("reject", "flat"): pytest.raises( ValueError, @@ -587,7 +590,7 @@ def my_reject_2(epoch_data): # Check if callable returns a tuple with reasons bad_types = [my_reject_2, ("HiHi"), (1, 1), None] - for val in bad_types: # protect against bad types + for val in bad_types: # protect against bad typesb for kwarg in ("reject", "flat"): with pytest.raises( TypeError, @@ -5273,3 +5276,73 @@ def test_empty_error(method, epochs_empty): pytest.importorskip("pandas") with pytest.raises(RuntimeError, match="is empty."): getattr(epochs_empty.copy(), method[0])(**method[1]) + + +def test_drop_bad_epochs_by_channel(): + """Test channel-specific epoch rejection.""" + # load raw and events data without loading data to disk + raw, ev, _ = _get_data(preload=False) + ep = Epochs(raw, ev, tmin=0, tmax=0.1, baseline=(0, 0), preload=False) + + # extract shape to set up reject mask (can't use shape as it loads the data) + n_epochs = len(ep.events) # number of epochs + n_channels = len(ep.ch_names) # number of channels + + # create a dummy reject mask with correct shape + reject_mask_dummy = np.zeros((n_epochs, n_channels)) + + # should throw an error + with pytest.raises(ValueError, match="must be preloaded"): + ep.drop_bad_epochs_by_channel(reject_mask_dummy) + + # load data + ep.load_data() + + # test if reject_mask == None returns epochs + assert ep == ep.drop_bad_epochs_by_channel(None) + + # set epochs to bad in reject mask + reject_mask = np.zeros((n_epochs, n_channels), dtype=bool) # all epochs are good + reject_mask[1, 0] = True # second epoch, first channel -> bad + reject_mask[1:, 1] = True # all epochs from channel two are bad + reject_mask[3, 2] = True # fourth epoch, third channel -> bad + + # this is a edge case, averaging throws an error because of empty channel + # realistically the user will drop the channel if all epochs are bad + # reject_mask[:, 1] = True # all epochs from channel two are bad + + # drop bad epochs + ep.drop_bad_epochs_by_channel(reject_mask) + + # verify bad epochs are NaN after dropping them + data = ep.get_data() + assert np.all(np.isnan(data[1, 0, :])) and np.all(np.isnan(data[3, 2, :])) + assert np.all(np.isnan(data[1:, 1, :])) + + # now we should have a nave per channel attribute + # now self.nave_per_channel should be assigned + assert hasattr(ep, "nave_per_channel") + + # sum over good epochs per channel + true_nave_per_channel = np.sum(~np.all(np.isnan(data), axis=2), axis=0) + assert np.all(ep.nave_per_channel == true_nave_per_channel) + + # channel length must match + assert len(ep.nave_per_channel) == len(ep.ch_names) + + # make sure averaging works (allowing for NaNs) + ev = ep.average() + + # check if nave of evoked data is minimum of nave_per_channel of epoched data + assert ev.nave == ep.nave_per_channel.min() + + # test mask that contains floats instead of bool + float_mask = reject_mask.astype(float) + ep.drop_bad_epochs_by_channel(float_mask) + data = ep.get_data() + assert np.all(np.isnan(data[1, 0, :])) and np.all(np.isnan(data[3, 2, :])) + + # test wrong shape of rejection mask + bad_mask = np.zeros((n_epochs, n_channels - 1), dtype=bool) + with pytest.raises(ValueError, match="reject_mask must have shape"): + ep.drop_bad_epochs_by_channel(bad_mask) diff --git a/mne/utils/check.py b/mne/utils/check.py index 087924c9656..dd5a0157fd1 100644 --- a/mne/utils/check.py +++ b/mne/utils/check.py @@ -951,12 +951,12 @@ def _check_combine(mode, valid=("mean", "median", "std"), axis=0): if mode == "mean": def fun(data): - return np.mean(data, axis=axis) + return np.nanmean(data, axis=axis) elif mode == "std": def fun(data): - return np.std(data, axis=axis) + return np.nanstd(data, axis=axis) elif mode == "median" or mode == np.median: