Skip to content

Commit 2fe55d8

Browse files
Make peak detection (locally_exclussive, matched_filtering) faster and more accurate. (#4341)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6f0ed09 commit 2fe55d8

File tree

13 files changed

+251
-327
lines changed

13 files changed

+251
-327
lines changed

src/spikeinterface/preprocessing/motion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
method="locally_exclusive",
122122
peak_sign="neg",
123123
detect_threshold=8.0,
124-
exclude_sweep_ms=0.1,
124+
exclude_sweep_ms=0.8,
125125
radius_um=75.0,
126126
),
127127
"select_kwargs": dict(),
@@ -139,7 +139,7 @@
139139
method="locally_exclusive",
140140
peak_sign="neg",
141141
detect_threshold=8.0,
142-
exclude_sweep_ms=0.1,
142+
exclude_sweep_ms=0.8,
143143
radius_um=50,
144144
),
145145
"select_kwargs": dict(),

src/spikeinterface/sortingcomponents/clustering/tests/test_clustering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def run_peaks(recording, job_kwargs):
4242
method_kwargs=dict(
4343
peak_sign="neg",
4444
detect_threshold=5,
45-
exclude_sweep_ms=0.1,
45+
exclude_sweep_ms=0.8,
4646
noise_levels=noise_levels,
4747
),
4848
job_kwargs=job_kwargs,

src/spikeinterface/sortingcomponents/matching/circus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def __init__(
626626
return_output=True,
627627
templates=None,
628628
peak_sign="neg",
629-
exclude_sweep_ms=0.1,
629+
exclude_sweep_ms=0.8,
630630
jitter_ms=0.1,
631631
detect_threshold=5,
632632
noise_levels=None,

src/spikeinterface/sortingcomponents/matching/nearest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __init__(
3333
templates,
3434
return_output=True,
3535
peak_sign="neg",
36-
exclude_sweep_ms=0.1,
36+
exclude_sweep_ms=0.8,
3737
detect_threshold=5,
3838
noise_levels=None,
3939
detection_radius_um=100.0,
@@ -158,7 +158,7 @@ def __init__(
158158
svd_model,
159159
return_output=True,
160160
peak_sign="neg",
161-
exclude_sweep_ms=0.1,
161+
exclude_sweep_ms=0.8,
162162
detect_threshold=5,
163163
noise_levels=None,
164164
detection_radius_um=100.0,

src/spikeinterface/sortingcomponents/matching/tdc_peeler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def __init__(
9595
templates,
9696
return_output=True,
9797
peak_sign="neg",
98-
exclude_sweep_ms=0.5,
98+
exclude_sweep_ms=0.8,
9999
peak_shift_ms=0.2,
100100
detect_threshold=5,
101101
noise_levels=None,

src/spikeinterface/sortingcomponents/motion/tests/test_motion_estimation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def setup_dataset_and_peaks(cache_folder):
3030
noise_levels=get_noise_levels(recording, return_in_uV=False),
3131
peak_sign="neg",
3232
detect_threshold=5,
33-
exclude_sweep_ms=0.1,
33+
exclude_sweep_ms=1.0,
3434
return_output=True,
3535
)
3636
extract_dense_waveforms = ExtractDenseWaveforms(

src/spikeinterface/sortingcomponents/peak_detection/by_channel.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@ class ByChannelPeakDetector(PeakDetector):
2828
Sign of the peak
2929
detect_threshold: float, default: 5
3030
Threshold, in median absolute deviations (MAD), to use to detect peaks
31-
exclude_sweep_ms: float, default: 0.1
31+
exclude_sweep_ms: float, default: 1.0
3232
Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size
33-
For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold,
34-
and no larger peaks are located during the 0.1ms preceding and following the peak
33+
For example, if `exclude_sweep_ms` is 1.0, a peak is detected if a sample crosses the threshold,
34+
and no larger peaks are located during the 1.0ms preceding and following the peak
3535
noise_levels: array or None, default: None
3636
Estimated noise levels to use, if already computed
3737
If not provide then it is estimated from a random snippet of the data
@@ -42,7 +42,7 @@ def __init__(
4242
recording,
4343
peak_sign="neg",
4444
detect_threshold=5,
45-
exclude_sweep_ms=0.1,
45+
exclude_sweep_ms=1.0,
4646
noise_levels=None,
4747
return_output=True,
4848
):
@@ -116,10 +116,10 @@ class ByChannelTorchPeakDetector(ByChannelPeakDetector):
116116
Sign of the peak
117117
detect_threshold: float, default: 5
118118
Threshold, in median absolute deviations (MAD), to use to detect peaks
119-
exclude_sweep_ms: float, default: 0.1
119+
exclude_sweep_ms: float, default: 1.0
120120
Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size
121-
For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold,
122-
and no larger peaks are located during the 0.1ms preceding and following the peak
121+
For example, if `exclude_sweep_ms` is 1.0, a peak is detected if a sample crosses the threshold,
122+
and no larger peaks are located during the 1.0ms preceding and following the peak
123123
noise_levels: array or None, default: None
124124
Estimated noise levels to use, if already computed.
125125
If not provide then it is estimated from a random snippet of the data
@@ -134,7 +134,7 @@ def __init__(
134134
recording,
135135
peak_sign="neg",
136136
detect_threshold=5,
137-
exclude_sweep_ms=0.1,
137+
exclude_sweep_ms=1.0,
138138
noise_levels=None,
139139
device=None,
140140
return_tensor=False,

src/spikeinterface/sortingcomponents/peak_detection/locally_exclusive.py

Lines changed: 76 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
recording,
5050
peak_sign="neg",
5151
detect_threshold=5,
52-
exclude_sweep_ms=0.1,
52+
exclude_sweep_ms=1.0,
5353
radius_um=50,
5454
noise_levels=None,
5555
return_output=True,
@@ -81,7 +81,8 @@ def __init__(
8181
self.neighbours_mask = self.channel_distance <= radius_um
8282

8383
def get_trace_margin(self):
84-
return self.exclude_sweep_size
84+
# the +1 in the border is important because we need peak in the border
85+
return self.exclude_sweep_size + 1
8586

8687
def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
8788
assert HAVE_NUMBA, "You need to install numba"
@@ -104,88 +105,84 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
104105
if HAVE_NUMBA:
105106
import numba
106107

108+
@numba.jit(nopython=True, parallel=False, nogil=True, fastmath=True)
107109
def detect_peaks_numba_locally_exclusive_on_chunk(
108-
traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask
110+
traces,
111+
peak_sign,
112+
abs_thresholds,
113+
exclude_sweep_size,
114+
neighbours_mask,
109115
):
116+
num_chans = traces.shape[1]
117+
num_samples = traces.shape[0]
118+
119+
do_pos = peak_sign in ("pos", "both")
120+
do_neg = peak_sign in ("neg", "both")
121+
122+
# first find peaks
123+
peak_mask = np.zeros(traces.shape, dtype="bool")
124+
for s in range(1, num_samples - 1):
125+
for chan_ind in range(num_chans):
126+
if do_neg:
127+
if (
128+
(traces[s, chan_ind] <= -abs_thresholds[chan_ind])
129+
and (traces[s, chan_ind] < traces[s - 1, chan_ind])
130+
and (traces[s, chan_ind] <= traces[s + 1, chan_ind])
131+
):
132+
peak_mask[s, chan_ind] = True
133+
134+
if do_pos:
135+
if (
136+
(traces[s, chan_ind] >= abs_thresholds[chan_ind])
137+
and (traces[s, chan_ind] > traces[s - 1, chan_ind])
138+
and (traces[s, chan_ind] >= traces[s + 1, chan_ind])
139+
):
140+
peak_mask[s, chan_ind] = True
141+
142+
samples_inds, chan_inds = np.nonzero(peak_mask)
143+
144+
npeaks = samples_inds.size
145+
keep_peak = np.ones(npeaks, dtype="bool")
146+
next_start = 0
147+
for i in range(npeaks):
148+
149+
if (samples_inds[i] < exclude_sweep_size + 1) or (
150+
samples_inds[i] >= (num_samples - exclude_sweep_size - 1)
151+
):
152+
keep_peak[i] = False
153+
continue
154+
155+
for j in range(next_start, npeaks):
156+
if i == j:
157+
continue
110158

111-
# if medians is not None:
112-
# traces = traces - medians
113-
114-
traces_center = traces[exclude_sweep_size:-exclude_sweep_size, :]
115-
116-
if peak_sign in ("pos", "both"):
117-
peak_mask = traces_center > abs_thresholds[None, :]
118-
peak_mask = _numba_detect_peak_pos(
119-
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
120-
)
121-
122-
if peak_sign in ("neg", "both"):
123-
if peak_sign == "both":
124-
peak_mask_pos = peak_mask.copy()
125-
126-
peak_mask = traces_center < -abs_thresholds[None, :]
127-
peak_mask = _numba_detect_peak_neg(
128-
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
129-
)
130-
131-
if peak_sign == "both":
132-
peak_mask = peak_mask | peak_mask_pos
133-
134-
# Find peaks and correct for time shift
135-
peak_sample_ind, peak_chan_ind = np.nonzero(peak_mask)
136-
peak_sample_ind += exclude_sweep_size
159+
if samples_inds[i] + exclude_sweep_size < samples_inds[j]:
160+
break
137161

138-
return peak_sample_ind, peak_chan_ind
139-
140-
@numba.jit(nopython=True, parallel=False)
141-
def _numba_detect_peak_pos(
142-
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
143-
):
144-
num_chans = traces_center.shape[1]
145-
for chan_ind in range(num_chans):
146-
for s in range(peak_mask.shape[0]):
147-
if not peak_mask[s, chan_ind]:
162+
if samples_inds[i] - exclude_sweep_size > samples_inds[j]:
163+
next_start = j
148164
continue
149-
for neighbour in range(num_chans):
150-
if not neighbours_mask[chan_ind, neighbour]:
151-
continue
152-
for i in range(exclude_sweep_size):
153-
if chan_ind != neighbour:
154-
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] >= traces_center[s, neighbour]
155-
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] > traces[s + i, neighbour]
156-
peak_mask[s, chan_ind] &= (
157-
traces_center[s, chan_ind] >= traces[exclude_sweep_size + s + i + 1, neighbour]
158-
)
159-
if not peak_mask[s, chan_ind]:
160-
break
161-
if not peak_mask[s, chan_ind]:
162-
break
163-
return peak_mask
164165

165-
@numba.jit(nopython=True, parallel=False)
166-
def _numba_detect_peak_neg(
167-
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
168-
):
169-
num_chans = traces_center.shape[1]
170-
for chan_ind in range(num_chans):
171-
for s in range(peak_mask.shape[0]):
172-
if not peak_mask[s, chan_ind]:
173-
continue
174-
for neighbour in range(num_chans):
175-
if not neighbours_mask[chan_ind, neighbour]:
176-
continue
177-
for i in range(exclude_sweep_size):
178-
if chan_ind != neighbour:
179-
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] <= traces_center[s, neighbour]
180-
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] < traces[s + i, neighbour]
181-
peak_mask[s, chan_ind] &= (
182-
traces_center[s, chan_ind] <= traces[exclude_sweep_size + s + i + 1, neighbour]
183-
)
184-
if not peak_mask[s, chan_ind]:
166+
# search for neighbors with higher amplitudes
167+
if neighbours_mask[chan_inds[i], chan_inds[j]]:
168+
# if inside spatial zone ...
169+
if abs(samples_inds[i] - samples_inds[j]) <= exclude_sweep_size:
170+
# ...and if inside tempral zone ...
171+
value_i = abs(traces[samples_inds[i], chan_inds[i]]) / abs_thresholds[chan_inds[i]]
172+
value_j = abs(traces[samples_inds[j], chan_inds[j]]) / abs_thresholds[chan_inds[j]]
173+
174+
if value_j > value_i:
175+
# ... and if smaller
176+
keep_peak[i] = False
177+
break
178+
if (value_j == value_i) & (samples_inds[i] > samples_inds[j]):
179+
# ... equal but after
180+
keep_peak[i] = False
185181
break
186-
if not peak_mask[s, chan_ind]:
187-
break
188-
return peak_mask
182+
183+
samples_inds, chan_inds = samples_inds[keep_peak], chan_inds[keep_peak]
184+
185+
return samples_inds, chan_inds
189186

190187

191188
class LocallyExclusiveTorchPeakDetector(ByChannelTorchPeakDetector):
@@ -205,7 +202,7 @@ def __init__(
205202
recording,
206203
peak_sign="neg",
207204
detect_threshold=5,
208-
exclude_sweep_ms=0.1,
205+
exclude_sweep_ms=1.0,
209206
noise_levels=None,
210207
device=None,
211208
radius_um=50,
@@ -275,7 +272,7 @@ def __init__(
275272
recording,
276273
peak_sign="neg",
277274
detect_threshold=5,
278-
exclude_sweep_ms=0.1,
275+
exclude_sweep_ms=1.0,
279276
radius_um=50,
280277
noise_levels=None,
281278
opencl_context_kwargs={},

0 commit comments

Comments
 (0)