@@ -49,7 +49,7 @@ def __init__(
4949 recording ,
5050 peak_sign = "neg" ,
5151 detect_threshold = 5 ,
52- exclude_sweep_ms = 0.1 ,
52+ exclude_sweep_ms = 1.0 ,
5353 radius_um = 50 ,
5454 noise_levels = None ,
5555 return_output = True ,
@@ -81,7 +81,8 @@ def __init__(
8181 self .neighbours_mask = self .channel_distance <= radius_um
8282
8383 def get_trace_margin (self ):
84- return self .exclude_sweep_size
84+ # the +1 in the border is important because we need peak in the border
85+ return self .exclude_sweep_size + 1
8586
8687 def compute (self , traces , start_frame , end_frame , segment_index , max_margin ):
8788 assert HAVE_NUMBA , "You need to install numba"
@@ -104,88 +105,84 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
104105if HAVE_NUMBA :
105106 import numba
106107
108+ @numba .jit (nopython = True , parallel = False , nogil = True , fastmath = True )
107109 def detect_peaks_numba_locally_exclusive_on_chunk (
108- traces , peak_sign , abs_thresholds , exclude_sweep_size , neighbours_mask
110+ traces ,
111+ peak_sign ,
112+ abs_thresholds ,
113+ exclude_sweep_size ,
114+ neighbours_mask ,
109115 ):
116+ num_chans = traces .shape [1 ]
117+ num_samples = traces .shape [0 ]
118+
119+ do_pos = peak_sign in ("pos" , "both" )
120+ do_neg = peak_sign in ("neg" , "both" )
121+
122+ # first find peaks
123+ peak_mask = np .zeros (traces .shape , dtype = "bool" )
124+ for s in range (1 , num_samples - 1 ):
125+ for chan_ind in range (num_chans ):
126+ if do_neg :
127+ if (
128+ (traces [s , chan_ind ] <= - abs_thresholds [chan_ind ])
129+ and (traces [s , chan_ind ] < traces [s - 1 , chan_ind ])
130+ and (traces [s , chan_ind ] <= traces [s + 1 , chan_ind ])
131+ ):
132+ peak_mask [s , chan_ind ] = True
133+
134+ if do_pos :
135+ if (
136+ (traces [s , chan_ind ] >= abs_thresholds [chan_ind ])
137+ and (traces [s , chan_ind ] > traces [s - 1 , chan_ind ])
138+ and (traces [s , chan_ind ] >= traces [s + 1 , chan_ind ])
139+ ):
140+ peak_mask [s , chan_ind ] = True
141+
142+ samples_inds , chan_inds = np .nonzero (peak_mask )
143+
144+ npeaks = samples_inds .size
145+ keep_peak = np .ones (npeaks , dtype = "bool" )
146+ next_start = 0
147+ for i in range (npeaks ):
148+
149+ if (samples_inds [i ] < exclude_sweep_size + 1 ) or (
150+ samples_inds [i ] >= (num_samples - exclude_sweep_size - 1 )
151+ ):
152+ keep_peak [i ] = False
153+ continue
154+
155+ for j in range (next_start , npeaks ):
156+ if i == j :
157+ continue
110158
111- # if medians is not None:
112- # traces = traces - medians
113-
114- traces_center = traces [exclude_sweep_size :- exclude_sweep_size , :]
115-
116- if peak_sign in ("pos" , "both" ):
117- peak_mask = traces_center > abs_thresholds [None , :]
118- peak_mask = _numba_detect_peak_pos (
119- traces , traces_center , peak_mask , exclude_sweep_size , abs_thresholds , peak_sign , neighbours_mask
120- )
121-
122- if peak_sign in ("neg" , "both" ):
123- if peak_sign == "both" :
124- peak_mask_pos = peak_mask .copy ()
125-
126- peak_mask = traces_center < - abs_thresholds [None , :]
127- peak_mask = _numba_detect_peak_neg (
128- traces , traces_center , peak_mask , exclude_sweep_size , abs_thresholds , peak_sign , neighbours_mask
129- )
130-
131- if peak_sign == "both" :
132- peak_mask = peak_mask | peak_mask_pos
133-
134- # Find peaks and correct for time shift
135- peak_sample_ind , peak_chan_ind = np .nonzero (peak_mask )
136- peak_sample_ind += exclude_sweep_size
159+ if samples_inds [i ] + exclude_sweep_size < samples_inds [j ]:
160+ break
137161
138- return peak_sample_ind , peak_chan_ind
139-
140- @numba .jit (nopython = True , parallel = False )
141- def _numba_detect_peak_pos (
142- traces , traces_center , peak_mask , exclude_sweep_size , abs_thresholds , peak_sign , neighbours_mask
143- ):
144- num_chans = traces_center .shape [1 ]
145- for chan_ind in range (num_chans ):
146- for s in range (peak_mask .shape [0 ]):
147- if not peak_mask [s , chan_ind ]:
162+ if samples_inds [i ] - exclude_sweep_size > samples_inds [j ]:
163+ next_start = j
148164 continue
149- for neighbour in range (num_chans ):
150- if not neighbours_mask [chan_ind , neighbour ]:
151- continue
152- for i in range (exclude_sweep_size ):
153- if chan_ind != neighbour :
154- peak_mask [s , chan_ind ] &= traces_center [s , chan_ind ] >= traces_center [s , neighbour ]
155- peak_mask [s , chan_ind ] &= traces_center [s , chan_ind ] > traces [s + i , neighbour ]
156- peak_mask [s , chan_ind ] &= (
157- traces_center [s , chan_ind ] >= traces [exclude_sweep_size + s + i + 1 , neighbour ]
158- )
159- if not peak_mask [s , chan_ind ]:
160- break
161- if not peak_mask [s , chan_ind ]:
162- break
163- return peak_mask
164165
165- @numba .jit (nopython = True , parallel = False )
166- def _numba_detect_peak_neg (
167- traces , traces_center , peak_mask , exclude_sweep_size , abs_thresholds , peak_sign , neighbours_mask
168- ):
169- num_chans = traces_center .shape [1 ]
170- for chan_ind in range (num_chans ):
171- for s in range (peak_mask .shape [0 ]):
172- if not peak_mask [s , chan_ind ]:
173- continue
174- for neighbour in range (num_chans ):
175- if not neighbours_mask [chan_ind , neighbour ]:
176- continue
177- for i in range (exclude_sweep_size ):
178- if chan_ind != neighbour :
179- peak_mask [s , chan_ind ] &= traces_center [s , chan_ind ] <= traces_center [s , neighbour ]
180- peak_mask [s , chan_ind ] &= traces_center [s , chan_ind ] < traces [s + i , neighbour ]
181- peak_mask [s , chan_ind ] &= (
182- traces_center [s , chan_ind ] <= traces [exclude_sweep_size + s + i + 1 , neighbour ]
183- )
184- if not peak_mask [s , chan_ind ]:
166+ # search for neighbors with higher amplitudes
167+ if neighbours_mask [chan_inds [i ], chan_inds [j ]]:
168+ # if inside spatial zone ...
169+ if abs (samples_inds [i ] - samples_inds [j ]) <= exclude_sweep_size :
170+ # ...and if inside tempral zone ...
171+ value_i = abs (traces [samples_inds [i ], chan_inds [i ]]) / abs_thresholds [chan_inds [i ]]
172+ value_j = abs (traces [samples_inds [j ], chan_inds [j ]]) / abs_thresholds [chan_inds [j ]]
173+
174+ if value_j > value_i :
175+ # ... and if smaller
176+ keep_peak [i ] = False
177+ break
178+ if (value_j == value_i ) & (samples_inds [i ] > samples_inds [j ]):
179+ # ... equal but after
180+ keep_peak [i ] = False
185181 break
186- if not peak_mask [s , chan_ind ]:
187- break
188- return peak_mask
182+
183+ samples_inds , chan_inds = samples_inds [keep_peak ], chan_inds [keep_peak ]
184+
185+ return samples_inds , chan_inds
189186
190187
191188class LocallyExclusiveTorchPeakDetector (ByChannelTorchPeakDetector ):
@@ -205,7 +202,7 @@ def __init__(
205202 recording ,
206203 peak_sign = "neg" ,
207204 detect_threshold = 5 ,
208- exclude_sweep_ms = 0.1 ,
205+ exclude_sweep_ms = 1.0 ,
209206 noise_levels = None ,
210207 device = None ,
211208 radius_um = 50 ,
@@ -275,7 +272,7 @@ def __init__(
275272 recording ,
276273 peak_sign = "neg" ,
277274 detect_threshold = 5 ,
278- exclude_sweep_ms = 0.1 ,
275+ exclude_sweep_ms = 1.0 ,
279276 radius_um = 50 ,
280277 noise_levels = None ,
281278 opencl_context_kwargs = {},
0 commit comments