Skip to content

Commit 87df00d

Browse files
ENH: make apply_function aware of channel index (#12206)
Co-authored-by: Mathieu Scheltienne <[email protected]>
1 parent e6b49ea commit 87df00d

File tree

9 files changed

+236
-24
lines changed

9 files changed

+236
-24
lines changed

doc/changes/devel/12206.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix bug in :meth:`mne.Epochs.apply_function` where data was handed down incorrectly in parallel processing, by `Dominik Welke`_.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Custom functions applied via :meth:`mne.io.Raw.apply_function`, :meth:`mne.Epochs.apply_function` or :meth:`mne.Evoked.apply_function` can now use ``ch_idx`` or ``ch_name`` to get access to the currently processed channel during channel wise processing.
2+
3+
:meth:`mne.Evoked.apply_function` can now also work on full data array instead of just channel wise, analogous to :meth:`mne.io.Raw.apply_function` and :meth:`mne.Epochs.apply_function`, by `Dominik Welke`_.

mne/epochs.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from collections import Counter
1717
from copy import deepcopy
1818
from functools import partial
19+
from inspect import getfullargspec
1920

2021
import numpy as np
2122
from scipy.interpolate import interp1d
@@ -1972,22 +1973,52 @@ def apply_function(
19721973
if dtype is not None and dtype != self._data.dtype:
19731974
self._data = self._data.astype(dtype)
19741975

1976+
args = getfullargspec(fun).args + getfullargspec(fun).kwonlyargs
1977+
if channel_wise is False:
1978+
if ("ch_idx" in args) or ("ch_name" in args):
1979+
raise ValueError(
1980+
"apply_function cannot access ch_idx or ch_name "
1981+
"when channel_wise=False"
1982+
)
1983+
if "ch_idx" in args:
1984+
logger.info("apply_function requested to access ch_idx")
1985+
if "ch_name" in args:
1986+
logger.info("apply_function requested to access ch_name")
1987+
19751988
if channel_wise:
19761989
parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs)
19771990
if n_jobs == 1:
1978-
_fun = partial(_check_fun, fun, **kwargs)
1991+
_fun = partial(_check_fun, fun)
19791992
# modify data inplace to save memory
1980-
for idx in picks:
1981-
self._data[:, idx, :] = np.apply_along_axis(
1982-
_fun, -1, data_in[:, idx, :]
1993+
for ch_idx in picks:
1994+
if "ch_idx" in args:
1995+
kwargs.update(ch_idx=ch_idx)
1996+
if "ch_name" in args:
1997+
kwargs.update(ch_name=self.info["ch_names"][ch_idx])
1998+
self._data[:, ch_idx, :] = np.apply_along_axis(
1999+
_fun, -1, data_in[:, ch_idx, :], **kwargs
19832000
)
19842001
else:
19852002
# use parallel function
2003+
_fun = partial(np.apply_along_axis, fun, -1)
19862004
data_picks_new = parallel(
1987-
p_fun(fun, data_in[:, p, :], **kwargs) for p in picks
2005+
p_fun(
2006+
_fun,
2007+
data_in[:, ch_idx, :],
2008+
**kwargs,
2009+
**{
2010+
k: v
2011+
for k, v in [
2012+
("ch_name", self.info["ch_names"][ch_idx]),
2013+
("ch_idx", ch_idx),
2014+
]
2015+
if k in args
2016+
},
2017+
)
2018+
for ch_idx in picks
19882019
)
1989-
for pp, p in enumerate(picks):
1990-
self._data[:, p, :] = data_picks_new[pp]
2020+
for run_idx, ch_idx in enumerate(picks):
2021+
self._data[:, ch_idx, :] = data_picks_new[run_idx]
19912022
else:
19922023
self._data = _check_fun(fun, data_in, **kwargs)
19932024

mne/evoked.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Copyright the MNE-Python contributors.
1010

1111
from copy import deepcopy
12+
from inspect import getfullargspec
1213
from typing import Union
1314

1415
import numpy as np
@@ -258,7 +259,15 @@ def get_data(self, picks=None, units=None, tmin=None, tmax=None):
258259

259260
@verbose
260261
def apply_function(
261-
self, fun, picks=None, dtype=None, n_jobs=None, verbose=None, **kwargs
262+
self,
263+
fun,
264+
picks=None,
265+
dtype=None,
266+
n_jobs=None,
267+
channel_wise=True,
268+
*,
269+
verbose=None,
270+
**kwargs,
262271
):
263272
"""Apply a function to a subset of channels.
264273
@@ -271,6 +280,9 @@ def apply_function(
271280
%(dtype_applyfun)s
272281
%(n_jobs)s Ignored if ``channel_wise=False`` as the workload
273282
is split across channels.
283+
%(channel_wise_applyfun)s
284+
285+
.. versionadded:: 1.6
274286
%(verbose)s
275287
%(kwargs_fun)s
276288
@@ -289,21 +301,55 @@ def apply_function(
289301
if dtype is not None and dtype != self._data.dtype:
290302
self._data = self._data.astype(dtype)
291303

304+
args = getfullargspec(fun).args + getfullargspec(fun).kwonlyargs
305+
if channel_wise is False:
306+
if ("ch_idx" in args) or ("ch_name" in args):
307+
raise ValueError(
308+
"apply_function cannot access ch_idx or ch_name "
309+
"when channel_wise=False"
310+
)
311+
if "ch_idx" in args:
312+
logger.info("apply_function requested to access ch_idx")
313+
if "ch_name" in args:
314+
logger.info("apply_function requested to access ch_name")
315+
292316
# check the dimension of the incoming evoked data
293317
_check_option("evoked.ndim", self._data.ndim, [2])
294318

295-
parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs)
296-
if n_jobs == 1:
297-
# modify data inplace to save memory
298-
for idx in picks:
299-
self._data[idx, :] = _check_fun(fun, data_in[idx, :], **kwargs)
319+
if channel_wise:
320+
parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs)
321+
if n_jobs == 1:
322+
# modify data inplace to save memory
323+
for ch_idx in picks:
324+
if "ch_idx" in args:
325+
kwargs.update(ch_idx=ch_idx)
326+
if "ch_name" in args:
327+
kwargs.update(ch_name=self.info["ch_names"][ch_idx])
328+
self._data[ch_idx, :] = _check_fun(
329+
fun, data_in[ch_idx, :], **kwargs
330+
)
331+
else:
332+
# use parallel function
333+
data_picks_new = parallel(
334+
p_fun(
335+
fun,
336+
data_in[ch_idx, :],
337+
**kwargs,
338+
**{
339+
k: v
340+
for k, v in [
341+
("ch_name", self.info["ch_names"][ch_idx]),
342+
("ch_idx", ch_idx),
343+
]
344+
if k in args
345+
},
346+
)
347+
for ch_idx in picks
348+
)
349+
for run_idx, ch_idx in enumerate(picks):
350+
self._data[ch_idx, :] = data_picks_new[run_idx]
300351
else:
301-
# use parallel function
302-
data_picks_new = parallel(
303-
p_fun(fun, data_in[p, :], **kwargs) for p in picks
304-
)
305-
for pp, p in enumerate(picks):
306-
self._data[p, :] = data_picks_new[pp]
352+
self._data[picks, :] = _check_fun(fun, data_in[picks, :], **kwargs)
307353

308354
return self
309355

mne/io/base.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from copy import deepcopy
1919
from dataclasses import dataclass, field
2020
from datetime import timedelta
21+
from inspect import getfullargspec
2122

2223
import numpy as np
2324

@@ -1087,19 +1088,50 @@ def apply_function(
10871088
if dtype is not None and dtype != self._data.dtype:
10881089
self._data = self._data.astype(dtype)
10891090

1091+
args = getfullargspec(fun).args + getfullargspec(fun).kwonlyargs
1092+
if channel_wise is False:
1093+
if ("ch_idx" in args) or ("ch_name" in args):
1094+
raise ValueError(
1095+
"apply_function cannot access ch_idx or ch_name "
1096+
"when channel_wise=False"
1097+
)
1098+
if "ch_idx" in args:
1099+
logger.info("apply_function requested to access ch_idx")
1100+
if "ch_name" in args:
1101+
logger.info("apply_function requested to access ch_name")
1102+
10901103
if channel_wise:
10911104
parallel, p_fun, n_jobs = parallel_func(_check_fun, n_jobs)
10921105
if n_jobs == 1:
10931106
# modify data inplace to save memory
1094-
for idx in picks:
1095-
self._data[idx, :] = _check_fun(fun, data_in[idx, :], **kwargs)
1107+
for ch_idx in picks:
1108+
if "ch_idx" in args:
1109+
kwargs.update(ch_idx=ch_idx)
1110+
if "ch_name" in args:
1111+
kwargs.update(ch_name=self.info["ch_names"][ch_idx])
1112+
self._data[ch_idx, :] = _check_fun(
1113+
fun, data_in[ch_idx, :], **kwargs
1114+
)
10961115
else:
10971116
# use parallel function
10981117
data_picks_new = parallel(
1099-
p_fun(fun, data_in[p], **kwargs) for p in picks
1118+
p_fun(
1119+
fun,
1120+
data_in[ch_idx],
1121+
**kwargs,
1122+
**{
1123+
k: v
1124+
for k, v in [
1125+
("ch_name", self.info["ch_names"][ch_idx]),
1126+
("ch_idx", ch_idx),
1127+
]
1128+
if k in args
1129+
},
1130+
)
1131+
for ch_idx in picks
11001132
)
1101-
for pp, p in enumerate(picks):
1102-
self._data[p, :] = data_picks_new[pp]
1133+
for run_idx, ch_idx in enumerate(picks):
1134+
self._data[ch_idx, :] = data_picks_new[run_idx]
11031135
else:
11041136
self._data[picks, :] = _check_fun(fun, data_in[picks, :], **kwargs)
11051137

mne/io/tests/test_apply_function.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,32 @@ def test_apply_function_verbose():
6363
assert out is raw
6464
raw.apply_function(printer, verbose=True)
6565
assert sio.getvalue().count("\n") == n_chan
66+
67+
68+
def test_apply_function_ch_access():
69+
"""Test apply_function is able to access channel idx."""
70+
71+
def _bad_ch_idx(x, ch_idx):
72+
assert x[0] == ch_idx
73+
return x
74+
75+
def _bad_ch_name(x, ch_name):
76+
assert isinstance(ch_name, str)
77+
assert x[0] == float(ch_name)
78+
return x
79+
80+
data = np.full((2, 10), np.arange(2).reshape(-1, 1))
81+
raw = RawArray(data, create_info(2, 1.0, "mag"))
82+
83+
# test ch_idx access in both code paths (parallel / 1 job)
84+
raw.apply_function(_bad_ch_idx)
85+
raw.apply_function(_bad_ch_idx, n_jobs=2)
86+
raw.apply_function(_bad_ch_name)
87+
raw.apply_function(_bad_ch_name, n_jobs=2)
88+
89+
# test input catches
90+
with pytest.raises(
91+
ValueError,
92+
match="cannot access.*when channel_wise=False",
93+
):
94+
raw.apply_function(_bad_ch_idx, channel_wise=False)

mne/tests/test_epochs.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4764,6 +4764,39 @@ def fun(data):
47644764
assert_array_equal(out.get_data(non_picks), epochs.get_data(non_picks))
47654765

47664766

4767+
def test_apply_function_epo_ch_access():
4768+
"""Test ch-access within apply function to epoch objects."""
4769+
4770+
def _bad_ch_idx(x, ch_idx):
4771+
assert x.shape == (46,)
4772+
assert x[0] == ch_idx
4773+
return x
4774+
4775+
def _bad_ch_name(x, ch_name):
4776+
assert x.shape == (46,)
4777+
assert isinstance(ch_name, str)
4778+
assert x[0] == float(ch_name)
4779+
return x
4780+
4781+
data = np.full((2, 100), np.arange(2).reshape(-1, 1))
4782+
raw = RawArray(data, create_info(2, 1.0, "mag"))
4783+
ev = np.array([[0, 0, 33], [50, 0, 33]])
4784+
ep = Epochs(raw, ev, tmin=0, tmax=45, baseline=None, preload=True)
4785+
4786+
# test ch_idx access in both code paths (parallel / 1 job)
4787+
ep.apply_function(_bad_ch_idx)
4788+
ep.apply_function(_bad_ch_idx, n_jobs=2)
4789+
ep.apply_function(_bad_ch_name)
4790+
ep.apply_function(_bad_ch_name, n_jobs=2)
4791+
4792+
# test input catches
4793+
with pytest.raises(
4794+
ValueError,
4795+
match="cannot access.*when channel_wise=False",
4796+
):
4797+
ep.apply_function(_bad_ch_idx, channel_wise=False)
4798+
4799+
47674800
@testing.requires_testing_data
47684801
def test_add_channels_picks():
47694802
"""Check that add_channels properly deals with picks."""

mne/tests/test_evoked.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,3 +959,33 @@ def fun(data, multiplier):
959959
applied = evoked.apply_function(fun, n_jobs=None, multiplier=mult)
960960
assert np.shape(applied.data) == np.shape(evoked_data)
961961
assert np.equal(applied.data, evoked_data * mult).all()
962+
963+
964+
def test_apply_function_evk_ch_access():
965+
"""Check ch-access within the apply_function method for evoked data."""
966+
967+
def _bad_ch_idx(x, ch_idx):
968+
assert x[0] == ch_idx
969+
return x
970+
971+
def _bad_ch_name(x, ch_name):
972+
assert isinstance(ch_name, str)
973+
assert x[0] == float(ch_name)
974+
return x
975+
976+
# create fake evoked data to use for checking apply_function
977+
data = np.full((2, 100), np.arange(2).reshape(-1, 1))
978+
evoked = EvokedArray(data, create_info(2, 1000.0, "eeg"))
979+
980+
# test ch_idx access in both code paths (parallel / 1 job)
981+
evoked.apply_function(_bad_ch_idx)
982+
evoked.apply_function(_bad_ch_idx, n_jobs=2)
983+
evoked.apply_function(_bad_ch_name)
984+
evoked.apply_function(_bad_ch_name, n_jobs=2)
985+
986+
# test input catches
987+
with pytest.raises(
988+
ValueError,
989+
match="cannot access.*when channel_wise=False",
990+
):
991+
evoked.apply_function(_bad_ch_idx, channel_wise=False)

mne/utils/docs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,6 +1586,13 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
15861586
fun has to be a timeseries (:class:`numpy.ndarray`). The function must
15871587
operate on an array of shape ``(n_times,)`` {}.
15881588
The function must return an :class:`~numpy.ndarray` shaped like its input.
1589+
1590+
.. note::
1591+
If ``channel_wise=True``, one can optionally access the index and/or the
1592+
name of the currently processed channel within the applied function.
1593+
This can enable tailored computations for different channels.
1594+
To use this feature, add ``ch_idx`` and/or ``ch_name`` as
1595+
additional argument(s) to your function definition.
15891596
"""
15901597
docdict["fun_applyfun"] = applyfun_fun_base.format(
15911598
" if ``channel_wise=True`` and ``(len(picks), n_times)`` otherwise"

0 commit comments

Comments
 (0)