Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
method="locally_exclusive",
peak_sign="neg",
detect_threshold=8.0,
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
radius_um=75.0,
),
"select_kwargs": dict(),
Expand All @@ -139,7 +139,7 @@
method="locally_exclusive",
peak_sign="neg",
detect_threshold=8.0,
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
radius_um=50,
),
"select_kwargs": dict(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run_peaks(recording, job_kwargs):
method_kwargs=dict(
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
noise_levels=noise_levels,
),
job_kwargs=job_kwargs,
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/matching/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def __init__(
return_output=True,
templates=None,
peak_sign="neg",
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
jitter_ms=0.1,
detect_threshold=5,
noise_levels=None,
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/sortingcomponents/matching/nearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
templates,
return_output=True,
peak_sign="neg",
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
detect_threshold=5,
noise_levels=None,
detection_radius_um=100.0,
Expand Down Expand Up @@ -158,7 +158,7 @@ def __init__(
svd_model,
return_output=True,
peak_sign="neg",
exclude_sweep_ms=0.1,
exclude_sweep_ms=0.8,
detect_threshold=5,
noise_levels=None,
detection_radius_um=100.0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(
templates,
return_output=True,
peak_sign="neg",
exclude_sweep_ms=0.5,
exclude_sweep_ms=0.8,
peak_shift_ms=0.2,
detect_threshold=5,
noise_levels=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def setup_dataset_and_peaks(cache_folder):
noise_levels=get_noise_levels(recording, return_in_uV=False),
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
return_output=True,
)
extract_dense_waveforms = ExtractDenseWaveforms(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ class ByChannelPeakDetector(PeakDetector):
Sign of the peak
detect_threshold: float, default: 5
Threshold, in median absolute deviations (MAD), to use to detect peaks
exclude_sweep_ms: float, default: 0.1
exclude_sweep_ms: float, default: 1.0
Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size
For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold,
and no larger peaks are located during the 0.1ms preceding and following the peak
For example, if `exclude_sweep_ms` is 1.0, a peak is detected if a sample crosses the threshold,
and no larger peaks are located during the 1.0ms preceding and following the peak
noise_levels: array or None, default: None
Estimated noise levels to use, if already computed
If not provide then it is estimated from a random snippet of the data
Expand All @@ -42,7 +42,7 @@ def __init__(
recording,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
noise_levels=None,
return_output=True,
):
Expand Down Expand Up @@ -116,10 +116,10 @@ class ByChannelTorchPeakDetector(ByChannelPeakDetector):
Sign of the peak
detect_threshold: float, default: 5
Threshold, in median absolute deviations (MAD), to use to detect peaks
exclude_sweep_ms: float, default: 0.1
exclude_sweep_ms: float, default: 1.0
Time, in ms, during which the peak is isolated. Exclusive param with exclude_sweep_size
For example, if `exclude_sweep_ms` is 0.1, a peak is detected if a sample crosses the threshold,
and no larger peaks are located during the 0.1ms preceding and following the peak
For example, if `exclude_sweep_ms` is 1.0, a peak is detected if a sample crosses the threshold,
and no larger peaks are located during the 1.0ms preceding and following the peak
noise_levels: array or None, default: None
Estimated noise levels to use, if already computed.
If not provide then it is estimated from a random snippet of the data
Expand All @@ -134,7 +134,7 @@ def __init__(
recording,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
noise_levels=None,
device=None,
return_tensor=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
recording,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
radius_um=50,
noise_levels=None,
return_output=True,
Expand Down Expand Up @@ -81,7 +81,8 @@ def __init__(
self.neighbours_mask = self.channel_distance <= radius_um

def get_trace_margin(self):
return self.exclude_sweep_size
# the +1 in the border is important because we need peak in the border
return self.exclude_sweep_size + 1

def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
assert HAVE_NUMBA, "You need to install numba"
Expand All @@ -104,88 +105,78 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
if HAVE_NUMBA:
import numba

@numba.jit(nopython=True, parallel=False, nogil=True, fastmath=True)
def detect_peaks_numba_locally_exclusive_on_chunk(
traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask
traces, peak_sign, abs_thresholds, exclude_sweep_size, neighbours_mask,
):
num_chans = traces.shape[1]
num_samples = traces.shape[0]


do_pos = peak_sign in ("pos", "both")
do_neg = peak_sign in ("neg", "both")

# first find peaks
peak_mask = np.zeros(traces.shape, dtype="bool")
for s in range(1, num_samples - 1):
for chan_ind in range(num_chans):
if do_neg:
if (traces[s, chan_ind] <= -abs_thresholds[chan_ind]) and \
(traces[s, chan_ind] < traces[s-1, chan_ind]) and \
(traces[s, chan_ind] <= traces[s+1, chan_ind]):
peak_mask[s, chan_ind] = True

if do_pos :
if (traces[s, chan_ind] >= abs_thresholds[chan_ind]) and \
(traces[s, chan_ind] > traces[s-1, chan_ind]) and \
(traces[s, chan_ind] >= traces[s+1, chan_ind]):
peak_mask[s, chan_ind] = True

samples_inds, chan_inds = np.nonzero(peak_mask)

npeaks = samples_inds.size
keep_peak = np.ones(npeaks, dtype="bool")
next_start = 0
for i in range(npeaks):

if (samples_inds[i] < exclude_sweep_size + 1) or (samples_inds[i]>= (num_samples - exclude_sweep_size - 1)):
keep_peak[i] = False
continue

for j in range(next_start, npeaks):
if i == j:
continue

# if medians is not None:
# traces = traces - medians

traces_center = traces[exclude_sweep_size:-exclude_sweep_size, :]

if peak_sign in ("pos", "both"):
peak_mask = traces_center > abs_thresholds[None, :]
peak_mask = _numba_detect_peak_pos(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
)
if samples_inds[i] + exclude_sweep_size < samples_inds[j]:
break

if peak_sign in ("neg", "both"):
if peak_sign == "both":
peak_mask_pos = peak_mask.copy()
if samples_inds[i] - exclude_sweep_size > samples_inds[j]:
next_start = j
continue

# search for neighbors with higher amplitudes
if neighbours_mask[chan_inds[i], chan_inds[j]]:
# if inside spatial zone ...
if abs(samples_inds[i] - samples_inds[j]) <= exclude_sweep_size:
# ...and if inside tempral zone ...
value_i = abs(traces[samples_inds[i], chan_inds[i]]) / abs_thresholds[chan_inds[i]]
value_j = abs(traces[samples_inds[j], chan_inds[j]]) / abs_thresholds[chan_inds[j]]

if (value_j > value_i):
# ... and if smaller
keep_peak[i] = False
break
if ((value_j == value_i) & (samples_inds[i] > samples_inds[j])):
# ... equal but after
keep_peak[i] = False
break

peak_mask = traces_center < -abs_thresholds[None, :]
peak_mask = _numba_detect_peak_neg(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
)

if peak_sign == "both":
peak_mask = peak_mask | peak_mask_pos
samples_inds, chan_inds = samples_inds[keep_peak], chan_inds[keep_peak]

# Find peaks and correct for time shift
peak_sample_ind, peak_chan_ind = np.nonzero(peak_mask)
peak_sample_ind += exclude_sweep_size
return samples_inds, chan_inds

return peak_sample_ind, peak_chan_ind

@numba.jit(nopython=True, parallel=False)
def _numba_detect_peak_pos(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
):
num_chans = traces_center.shape[1]
for chan_ind in range(num_chans):
for s in range(peak_mask.shape[0]):
if not peak_mask[s, chan_ind]:
continue
for neighbour in range(num_chans):
if not neighbours_mask[chan_ind, neighbour]:
continue
for i in range(exclude_sweep_size):
if chan_ind != neighbour:
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] >= traces_center[s, neighbour]
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] > traces[s + i, neighbour]
peak_mask[s, chan_ind] &= (
traces_center[s, chan_ind] >= traces[exclude_sweep_size + s + i + 1, neighbour]
)
if not peak_mask[s, chan_ind]:
break
if not peak_mask[s, chan_ind]:
break
return peak_mask

@numba.jit(nopython=True, parallel=False)
def _numba_detect_peak_neg(
traces, traces_center, peak_mask, exclude_sweep_size, abs_thresholds, peak_sign, neighbours_mask
):
num_chans = traces_center.shape[1]
for chan_ind in range(num_chans):
for s in range(peak_mask.shape[0]):
if not peak_mask[s, chan_ind]:
continue
for neighbour in range(num_chans):
if not neighbours_mask[chan_ind, neighbour]:
continue
for i in range(exclude_sweep_size):
if chan_ind != neighbour:
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] <= traces_center[s, neighbour]
peak_mask[s, chan_ind] &= traces_center[s, chan_ind] < traces[s + i, neighbour]
peak_mask[s, chan_ind] &= (
traces_center[s, chan_ind] <= traces[exclude_sweep_size + s + i + 1, neighbour]
)
if not peak_mask[s, chan_ind]:
break
if not peak_mask[s, chan_ind]:
break
return peak_mask


class LocallyExclusiveTorchPeakDetector(ByChannelTorchPeakDetector):
Expand All @@ -205,7 +196,7 @@ def __init__(
recording,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
noise_levels=None,
device=None,
radius_um=50,
Expand Down Expand Up @@ -275,7 +266,7 @@ def __init__(
recording,
peak_sign="neg",
detect_threshold=5,
exclude_sweep_ms=0.1,
exclude_sweep_ms=1.0,
radius_um=50,
noise_levels=None,
opencl_context_kwargs={},
Expand Down
Loading
Loading