99# Copyright the MNE-Python contributors.
1010
1111from copy import deepcopy
12+ from inspect import getfullargspec
1213from typing import Union
1314
1415import numpy as np
@@ -258,7 +259,15 @@ def get_data(self, picks=None, units=None, tmin=None, tmax=None):
258259
259260 @verbose
260261 def apply_function (
261- self , fun , picks = None , dtype = None , n_jobs = None , verbose = None , ** kwargs
262+ self ,
263+ fun ,
264+ picks = None ,
265+ dtype = None ,
266+ n_jobs = None ,
267+ channel_wise = True ,
268+ * ,
269+ verbose = None ,
270+ ** kwargs ,
262271 ):
263272 """Apply a function to a subset of channels.
264273
@@ -271,6 +280,9 @@ def apply_function(
271280 %(dtype_applyfun)s
272281 %(n_jobs)s Ignored if ``channel_wise=False`` as the workload
273282 is split across channels.
283+ %(channel_wise_applyfun)s
284+
285+ .. versionadded:: 1.6
274286 %(verbose)s
275287 %(kwargs_fun)s
276288
@@ -289,21 +301,55 @@ def apply_function(
289301 if dtype is not None and dtype != self ._data .dtype :
290302 self ._data = self ._data .astype (dtype )
291303
304+ args = getfullargspec (fun ).args + getfullargspec (fun ).kwonlyargs
305+ if channel_wise is False :
306+ if ("ch_idx" in args ) or ("ch_name" in args ):
307+ raise ValueError (
308+ "apply_function cannot access ch_idx or ch_name "
309+ "when channel_wise=False"
310+ )
311+ if "ch_idx" in args :
312+ logger .info ("apply_function requested to access ch_idx" )
313+ if "ch_name" in args :
314+ logger .info ("apply_function requested to access ch_name" )
315+
292316 # check the dimension of the incoming evoked data
293317 _check_option ("evoked.ndim" , self ._data .ndim , [2 ])
294318
295- parallel , p_fun , n_jobs = parallel_func (_check_fun , n_jobs )
296- if n_jobs == 1 :
297- # modify data inplace to save memory
298- for idx in picks :
299- self ._data [idx , :] = _check_fun (fun , data_in [idx , :], ** kwargs )
319+ if channel_wise :
320+ parallel , p_fun , n_jobs = parallel_func (_check_fun , n_jobs )
321+ if n_jobs == 1 :
322+ # modify data inplace to save memory
323+ for ch_idx in picks :
324+ if "ch_idx" in args :
325+ kwargs .update (ch_idx = ch_idx )
326+ if "ch_name" in args :
327+ kwargs .update (ch_name = self .info ["ch_names" ][ch_idx ])
328+ self ._data [ch_idx , :] = _check_fun (
329+ fun , data_in [ch_idx , :], ** kwargs
330+ )
331+ else :
332+ # use parallel function
333+ data_picks_new = parallel (
334+ p_fun (
335+ fun ,
336+ data_in [ch_idx , :],
337+ ** kwargs ,
338+ ** {
339+ k : v
340+ for k , v in [
341+ ("ch_name" , self .info ["ch_names" ][ch_idx ]),
342+ ("ch_idx" , ch_idx ),
343+ ]
344+ if k in args
345+ },
346+ )
347+ for ch_idx in picks
348+ )
349+ for run_idx , ch_idx in enumerate (picks ):
350+ self ._data [ch_idx , :] = data_picks_new [run_idx ]
300351 else :
301- # use parallel function
302- data_picks_new = parallel (
303- p_fun (fun , data_in [p , :], ** kwargs ) for p in picks
304- )
305- for pp , p in enumerate (picks ):
306- self ._data [p , :] = data_picks_new [pp ]
352+ self ._data [picks , :] = _check_fun (fun , data_in [picks , :], ** kwargs )
307353
308354 return self
309355
0 commit comments