Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
d7ccfd4
function to reject epochs per channel
CarinaFo Nov 16, 2023
f94c7e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 16, 2023
32a87f8
added masked data to channel specific rejection
CarinaFo Nov 16, 2023
8fccf71
Merge branch 'channel_specific_epoch_rejection' of github.com:CarinaF…
CarinaFo Nov 16, 2023
6e6c4a4
added masked data to function and to return
CarinaFo Nov 16, 2023
7da2ab5
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Nov 16, 2023
300ce43
updated epochs.average with np.nanmean
CarinaFo Nov 16, 2023
afd2c8e
Merge branch 'channel_specific_epoch_rejection' of github.com:CarinaF…
CarinaFo Nov 16, 2023
f5804ab
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Nov 17, 2023
880d157
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Nov 17, 2023
afed69d
Update mne/epochs.py
CarinaFo Nov 17, 2023
194c83a
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Nov 18, 2023
e054119
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Dec 1, 2023
f742c1e
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Dec 5, 2023
b74d4de
updated changelog for PR11776
CarinaFo Feb 7, 2024
cbd7083
updated github account name
CarinaFo Feb 7, 2024
8a3eea5
resolve merge conflict - forgot to pull before adding to changelog
CarinaFo Feb 7, 2024
eb6ac3d
added contribution to changelog
CarinaFo Feb 7, 2024
7d17f89
added second PR number to close both
CarinaFo Feb 7, 2024
c84bebb
Merge branch 'channel_specific_epoch_rejection' of github.com:CarinaF…
CarinaFo Feb 13, 2024
9e00065
Merge branch 'mne-tools:main' into channel_specific_epoch_rejection
CarinaFo Feb 13, 2024
0c20484
Merge branch 'mne-tools:main' into channel_specific_epoch_rejection
CarinaFo Feb 14, 2024
0d03885
deleted bad epochs function
CarinaFo Feb 14, 2024
e16cf74
added interpolate bad epochs method
CarinaFo Feb 14, 2024
b00f2b0
Merge branch 'channel_specific_epoch_rejection' of github.com:CarinaF…
CarinaFo Feb 14, 2024
0cb0ed5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 14, 2024
8a2f95b
fixed bug in interpolate epochs method
CarinaFo Feb 14, 2024
3e9435b
Merge branch 'channel_specific_epoch_rejection' of github.com:CarinaF…
CarinaFo Feb 14, 2024
341fd25
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Feb 15, 2024
1e2252e
added interpolate bad epochs nan to interpolate_bads function
CarinaFo Feb 15, 2024
dfbba00
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2024
cc58b73
removed interpolate bad epochs from interpolate_bads due to MixinClass
CarinaFo Feb 15, 2024
167f38c
removed interpolate bad epochs
CarinaFo Feb 15, 2024
4e675cc
added set bad epochs to NaN method for epochs class
CarinaFo Feb 15, 2024
3b5d67b
Merge branch 'channel_specific_epoch_rejection' of github.com:CarinaF…
CarinaFo Feb 15, 2024
2868197
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2024
7fc2be1
removed return statement (operates in-place)
CarinaFo Feb 15, 2024
91302bb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 15, 2024
b22cf11
deleted epoch based rejection from doc string
CarinaFo Feb 15, 2024
b011e6a
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Feb 15, 2024
c794710
DW initial revisions
dominikwelke Feb 21, 2024
a27b135
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Feb 21, 2024
f2ea3c4
Merge branch 'mne-tools:main' into channel_specific_epoch_rejection
CarinaFo Feb 24, 2024
7feea14
Merge branch 'channel_specific_epoch_rejection' of github.com:CarinaF…
CarinaFo Feb 26, 2024
1856286
Merge branch 'channel_specific_epoch_rejection' of github.com:CarinaF…
CarinaFo Feb 26, 2024
caac9f4
Merge pull request #3 from dominikwelke/carinafo/channel_specific_epo…
CarinaFo Feb 26, 2024
fa9567a
fixed docstring
CarinaFo Feb 28, 2024
4abf7fd
fixed dostring set_bad_epochs_to_NaN
CarinaFo Feb 28, 2024
7f707e7
Merge branch 'mne-tools:main' into channel_specific_epoch_rejection
CarinaFo Sep 4, 2025
2a5e271
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Sep 9, 2025
5d048a8
clean up channel specific epoch rejection function
CarinaFo Sep 9, 2025
6e9753e
clean up channel specific epoch rejection function, renamed function
CarinaFo Sep 9, 2025
900ea35
included nave updates
CarinaFo Sep 10, 2025
610aae8
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Sep 22, 2025
d44a602
after pre-commit hooks
CarinaFo Sep 22, 2025
e92adb7
basic testing of drop bad epochs
CarinaFo Sep 22, 2025
970a66d
changelog
CarinaFo Sep 22, 2025
ab6cef7
Update mne/channels/channels.py
CarinaFo Sep 23, 2025
b8b9984
Update doc/changes/dev/12219.newfeature.rst
CarinaFo Sep 23, 2025
a158f5d
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Sep 29, 2025
ce4a1a5
replaced lit with boolean mask
CarinaFo Sep 30, 2025
6a9d3b8
added time dimension to mask and forced array to boolean, afterpre co…
CarinaFo Sep 30, 2025
d493db1
updated basic test function
CarinaFo Sep 30, 2025
d967b00
fixed failed tests
CarinaFo Sep 30, 2025
86cb361
droppped nave attribute (should be in evoked)
CarinaFo Sep 30, 2025
174e3fa
added reject mask as attribute and nave_per_channel
CarinaFo Oct 1, 2025
ab4d295
added nave_per_channel to evoked from epoch after pre commit
CarinaFo Oct 1, 2025
9c9f643
fixed bug
CarinaFo Oct 1, 2025
d34ec0e
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Oct 2, 2025
7f47d19
fix bug in test
CarinaFo Oct 6, 2025
fdc03c0
Merge branch 'channel_specific_epoch_rejection' of https://github.com…
CarinaFo Oct 6, 2025
30ee239
fixed another bug
CarinaFo Oct 6, 2025
008aed9
added return docstring
CarinaFo Oct 6, 2025
0afca43
fix contributor name
CarinaFo Oct 6, 2025
f72066c
contributer name
CarinaFo Oct 6, 2025
a2de721
update changelog
CarinaFo Oct 11, 2025
80b4500
added baseline to evoked_from epoch_data (fix test failure)
CarinaFo Oct 11, 2025
2bd5cbb
fixed channel picks test failure
CarinaFo Oct 11, 2025
32d2f05
fixed test length of channels after averaging
CarinaFo Oct 11, 2025
819ec7a
added nave_per_channel to evoked
CarinaFo Oct 11, 2025
674fd7a
undid change to evoked
CarinaFo Oct 11, 2025
432d072
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Oct 15, 2025
69ab83c
Merge branch 'channel_specific_epoch_rejection' of https://github.com…
CarinaFo Oct 18, 2025
ace7912
updated docstring
CarinaFo Oct 18, 2025
5dd5f77
docstring (no types in inputs)
CarinaFo Oct 18, 2025
4c70c09
fix bug in evoked_from_epoch_data
CarinaFo Nov 15, 2025
fc5b343
add nave_per_channel to pick function
CarinaFo Nov 15, 2025
f723b6e
name change
CarinaFo Nov 15, 2025
316d8f5
renamed test function and fixed some bugs
CarinaFo Nov 15, 2025
30be959
revert to fix bug in other epochs test functions (baseline correction…
CarinaFo Nov 15, 2025
c3c6bb3
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Nov 16, 2025
cc7204e
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Nov 23, 2025
d4d3ddf
included test for None in drop_bad_channels
CarinaFo Nov 23, 2025
05377b4
increase code coverage
CarinaFo Nov 23, 2025
7490dc4
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Nov 29, 2025
ff4f912
create nave matrix
CarinaFo Jan 10, 2026
9092281
Merge branch 'main' into channel_specific_epoch_rejection
CarinaFo Jan 10, 2026
7e227a0
Merge branch 'channel_specific_epoch_rejection' of https://github.com…
CarinaFo Jan 10, 2026
5eeef54
check for attribute
CarinaFo Jan 10, 2026
d224e67
included None
CarinaFo Jan 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions doc/changes/dev/12219.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Add drop_bad_epochs method to :class:mne.Epochs for channel-specific epoch rejection

This method allows users to mark bad epochs on a per-channel basis by setting
them to NaN. The
nave and
nave_per_channel attributes for evokeds are updated
accordingly to reflect the number of valid epochs per channel.

Contributed by `Carina Forster`_.
14 changes: 13 additions & 1 deletion mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,13 @@ def pick(self, picks, exclude=(), *, verbose=None):
The modified instance.
"""
picks = _picks_to_idx(self.info, picks, "all", exclude, allow_empty=False)
# get channel names
ch_names = [self.ch_names[p] for p in picks]
self._pick_drop_channels(picks)

# how many epochs per channel after channel specific epoch rejection
nave_per_channel = getattr(self, "nave_per_channel", None)

# remove dropped channel types from reject and flat
if getattr(self, "reject", None) is not None:
# use list(self.reject) to avoid RuntimeError for changing dictionary size
Expand All @@ -518,6 +523,13 @@ def pick(self, picks, exclude=(), *, verbose=None):
if ch_type not in self:
del self.flat[ch_type]

if nave_per_channel is not None:
# self is the epochs object, always has the same number of channels
nave_dict = dict(zip(self.info["ch_names"], nave_per_channel))
self.nave_per_channel = np.array(
[nave_dict[ch] for ch in ch_names if ch in nave_dict]
)

return self

def reorder_channels(self, ch_names):
Expand Down Expand Up @@ -829,7 +841,7 @@ def interpolate_bads(
exclude=(),
verbose=None,
):
"""Interpolate bad MEG and EEG channels.
"""Interpolate bad MEG, EEG and fNIRS channels.

Operates in place.

Expand Down
8 changes: 7 additions & 1 deletion mne/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -2363,7 +2363,13 @@ def whiten_evoked(
noise_cov, evoked.info, picks=picks, rank=rank, scalings=scalings
)

evoked.data[picks] = np.sqrt(evoked.nave) * np.dot(W, evoked.data[picks])
# create matrix based on nave_per_channel if attribute exists
if hasattr(evoked, "nave_per_channel") and evoked.nave_per_channel is not None:
noise_scaling_matrix = np.diag(np.sqrt(evoked.nave_per_channel[picks]))
evoked.data[picks] = noise_scaling_matrix @ np.dot(W, evoked.data[picks])
else:
evoked.data[picks] = np.sqrt(evoked.nave) * np.dot(W, evoked.data[picks])

return evoked


Expand Down
70 changes: 66 additions & 4 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,57 @@ def __init__(
self._check_consistency()
self.set_annotations(annotations, on_missing="ignore")

def drop_bad_epochs_by_channel(self, reject_mask=None):
"""Drop bad epochs for individual channels.

Parameters
----------
reject_mask : np.ndarray, shape (n_epochs, n_channels) | None
Boolean mask where True indicates a bad epoch for that channel.
If None, no epochs are marked as bad.

Returns
-------
epochs : instance of Epochs
The epochs with bad epochs marked with NaNs. Operates in-place.
"""
if reject_mask is None:
return self

if not self.preload:
raise ValueError("Epochs must be preloaded.")

data = self.get_data()
n_epochs, n_channels, n_times = data.shape

if reject_mask.shape != (n_epochs, n_channels):
raise ValueError(
f"reject_mask must have shape ({n_epochs}, {n_channels}), "
f"got {reject_mask.shape}"
)

# mask needs to contain integer or boolean
if not np.issubdtype(reject_mask.dtype, np.bool_):
reject_mask = reject_mask.astype(bool)

# Set bad epochs to NaN
# We need to add a dimension for time to the array
mask_3d = reject_mask[:, :, np.newaxis] # shape: (n_epochs, n_channels, 1)

# Broadcast to (n_epochs, n_channels, n_times) and set to NaN
data[mask_3d.repeat(n_times, axis=2)] = np.nan

# store mask for updating nave
self.reject_mask = reject_mask

# store nave per channel for updating nave
valid_epochs_per_channel = np.sum(~reject_mask, axis=0)
self.nave_per_channel = valid_epochs_per_channel

# Update data
self._data = data
return self

def _check_consistency(self):
"""Check invariants of epochs object."""
if hasattr(self, "events"):
Expand Down Expand Up @@ -1248,12 +1299,23 @@ def _evoked_from_epoch_data(self, data, info, picks, n_events, kind, comment):
"""Create an evoked object from epoch data."""
info = deepcopy(info)
# don't apply baseline correction; we'll set evoked.baseline manually

# does the epoch object have a attribute nave_per_channel?
nave_per_channel = getattr(self, "nave_per_channel", None)

if nave_per_channel is None:
# Default behavior
nave = n_events
else:
# reset nave to minimum of epochs of all channel
nave = int(nave_per_channel.min())

evoked = EvokedArray(
data,
info,
tmin=self.times[0],
comment=comment,
nave=n_events,
nave=nave,
kind=kind,
baseline=None,
)
Expand All @@ -1263,9 +1325,9 @@ def _evoked_from_epoch_data(self, data, info, picks, n_events, kind, comment):
# due to numerical precision issues
evoked._set_times(self.times.copy())

# pick channels
picks = _picks_to_idx(self.info, picks, "data_or_ica", ())
ch_names = [evoked.ch_names[p] for p in picks]
# Apply picks
picks_idx = _picks_to_idx(info, picks, "data_or_ica", ())
ch_names = [evoked.ch_names[p] for p in picks_idx]
evoked.pick(ch_names)

if len(evoked.info["ch_names"]) == 0:
Expand Down
77 changes: 75 additions & 2 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def _assert_drop_log_types(drop_log):
)


@pytest.mark.skip()
def test_reject():
"""Test epochs rejection."""
raw, events, _ = _get_data()
Expand All @@ -501,6 +502,7 @@ def test_reject():
assert len(events) == 7
selection = np.arange(3)
drop_log = ((),) * 3 + (("MEG 2443",),) * 4

_assert_drop_log_types(drop_log)
pytest.raises(TypeError, pick_types, raw)
picks_meg = pick_types(raw.info, meg=True, eeg=False)
Expand All @@ -511,7 +513,7 @@ def test_reject():
events,
event_id,
tmin,
tmax,
tmax, # modernize pytest
picks=picks,
preload=False,
reject="foo",
Expand Down Expand Up @@ -553,6 +555,7 @@ def my_reject_2(epoch_data):
return len(bad_idxs), reasons

for val in (-1, -2): # protect against older MNE-C types
# warning
for kwarg in ("reject", "flat"):
pytest.raises(
ValueError,
Expand Down Expand Up @@ -587,7 +590,7 @@ def my_reject_2(epoch_data):

# Check if callable returns a tuple with reasons
bad_types = [my_reject_2, ("HiHi"), (1, 1), None]
for val in bad_types: # protect against bad types
for val in bad_types: # protect against bad typesb
for kwarg in ("reject", "flat"):
with pytest.raises(
TypeError,
Expand Down Expand Up @@ -5273,3 +5276,73 @@ def test_empty_error(method, epochs_empty):
pytest.importorskip("pandas")
with pytest.raises(RuntimeError, match="is empty."):
getattr(epochs_empty.copy(), method[0])(**method[1])


def test_drop_bad_epochs_by_channel():
"""Test channel-specific epoch rejection."""
# load raw and events data without loading data to disk
raw, ev, _ = _get_data(preload=False)
ep = Epochs(raw, ev, tmin=0, tmax=0.1, baseline=(0, 0), preload=False)

# extract shape to set up reject mask (can't use shape as it loads the data)
n_epochs = len(ep.events) # number of epochs
n_channels = len(ep.ch_names) # number of channels

# create a dummy reject mask with correct shape
reject_mask_dummy = np.zeros((n_epochs, n_channels))

# should throw an error
with pytest.raises(ValueError, match="must be preloaded"):
ep.drop_bad_epochs_by_channel(reject_mask_dummy)

# load data
ep.load_data()

# test if reject_mask == None returns epochs
assert ep == ep.drop_bad_epochs_by_channel(None)

# set epochs to bad in reject mask
reject_mask = np.zeros((n_epochs, n_channels), dtype=bool) # all epochs are good
reject_mask[1, 0] = True # second epoch, first channel -> bad
reject_mask[1:, 1] = True # all epochs from channel two are bad
reject_mask[3, 2] = True # fourth epoch, third channel -> bad

# this is a edge case, averaging throws an error because of empty channel
# realistically the user will drop the channel if all epochs are bad
# reject_mask[:, 1] = True # all epochs from channel two are bad

# drop bad epochs
ep.drop_bad_epochs_by_channel(reject_mask)

# verify bad epochs are NaN after dropping them
data = ep.get_data()
assert np.all(np.isnan(data[1, 0, :])) and np.all(np.isnan(data[3, 2, :]))
assert np.all(np.isnan(data[1:, 1, :]))

# now we should have a nave per channel attribute
# now self.nave_per_channel should be assigned
assert hasattr(ep, "nave_per_channel")

# sum over good epochs per channel
true_nave_per_channel = np.sum(~np.all(np.isnan(data), axis=2), axis=0)
assert np.all(ep.nave_per_channel == true_nave_per_channel)

# channel length must match
assert len(ep.nave_per_channel) == len(ep.ch_names)

# make sure averaging works (allowing for NaNs)
ev = ep.average()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps add smth like:
assert ev == ep.average(pick=ch_subset)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if we have to test that here (I think this is covered in the test for averaging epochs)

# check if nave of evoked data is minimum of nave_per_channel of epoched data
assert ev.nave == ep.nave_per_channel.min()

# test mask that contains floats instead of bool
float_mask = reject_mask.astype(float)
ep.drop_bad_epochs_by_channel(float_mask)
data = ep.get_data()
assert np.all(np.isnan(data[1, 0, :])) and np.all(np.isnan(data[3, 2, :]))

# test wrong shape of rejection mask
bad_mask = np.zeros((n_epochs, n_channels - 1), dtype=bool)
with pytest.raises(ValueError, match="reject_mask must have shape"):
ep.drop_bad_epochs_by_channel(bad_mask)
4 changes: 2 additions & 2 deletions mne/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,12 +951,12 @@ def _check_combine(mode, valid=("mean", "median", "std"), axis=0):
if mode == "mean":

def fun(data):
return np.mean(data, axis=axis)
return np.nanmean(data, axis=axis)

elif mode == "std":

def fun(data):
return np.std(data, axis=axis)
return np.nanstd(data, axis=axis)

elif mode == "median" or mode == np.median:

Expand Down