Skip to content

Commit a12830e

Browse files
Update nearest peeler (#4223)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5dfb788 commit a12830e

File tree

1 file changed

+88
-33
lines changed
  • src/spikeinterface/sortingcomponents/matching

1 file changed

+88
-33
lines changed

src/spikeinterface/sortingcomponents/matching/nearest.py

Lines changed: 88 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,63 @@ def __init__(
3636
exclude_sweep_ms=0.1,
3737
detect_threshold=5,
3838
noise_levels=None,
39-
radius_um=100.0,
39+
detection_radius_um=100.0,
40+
neighborhood_radius_um=50.0,
41+
sparsity_radius_um=100.0,
4042
):
4143

4244
BaseTemplateMatching.__init__(self, recording, templates, return_output=return_output)
4345

44-
self.templates_array = self.templates.get_dense_templates()
45-
4646
self.noise_levels = noise_levels
4747
self.abs_threholds = self.noise_levels * detect_threshold
4848
self.peak_sign = peak_sign
49-
channel_distance = get_channel_distances(recording)
50-
self.neighbours_mask = channel_distance <= radius_um
49+
self.channel_distance = get_channel_distances(recording)
50+
self.neighbours_mask = self.channel_distance <= detection_radius_um
51+
52+
num_templates = len(self.templates.unit_ids)
53+
num_channels = recording.get_num_channels()
54+
55+
if neighborhood_radius_um is not None:
56+
from spikeinterface.core.template_tools import get_template_extremum_channel
57+
58+
best_channels = get_template_extremum_channel(self.templates, peak_sign=self.peak_sign, outputs="index")
59+
best_channels = np.array([best_channels[i] for i in templates.unit_ids])
60+
channel_locations = recording.get_channel_locations()
61+
template_distances = np.linalg.norm(
62+
channel_locations[:, None] - channel_locations[best_channels][np.newaxis, :], axis=2
63+
)
64+
self.neighborhood_mask = template_distances <= neighborhood_radius_um
65+
else:
66+
self.neighborhood_mask = np.ones((num_channels, num_templates), dtype=bool)
67+
68+
if sparsity_radius_um is not None:
69+
if not templates.are_templates_sparse():
70+
from spikeinterface.core.sparsity import compute_sparsity
71+
72+
sparsity = compute_sparsity(
73+
templates, method="radius", radius_um=sparsity_radius_um, peak_sign=self.peak_sign
74+
)
75+
else:
76+
sparsity = templates.sparsity
77+
78+
self.sparsity_mask = np.zeros((num_channels, num_channels), dtype=bool)
79+
for channel_index in np.arange(num_channels):
80+
mask = self.neighborhood_mask[channel_index]
81+
self.sparsity_mask[channel_index] = np.sum(sparsity.mask[mask], axis=0) > 0
82+
else:
83+
self.sparsity_mask = np.ones((num_channels, num_channels), dtype=bool)
84+
85+
self.templates_array = self.templates.get_dense_templates()
5186
self.exclude_sweep_size = int(exclude_sweep_ms * recording.get_sampling_frequency() / 1000.0)
5287
self.nbefore = self.templates.nbefore
5388
self.nafter = self.templates.nafter
5489
self.margin = max(self.nbefore, self.nafter)
90+
self.lookup_tables = {}
91+
self.lookup_tables["templates"] = {}
92+
self.lookup_tables["channels"] = {}
93+
for i in range(num_channels):
94+
self.lookup_tables["templates"][i] = np.flatnonzero(self.neighborhood_mask[i])
95+
self.lookup_tables["channels"][i] = np.flatnonzero(self.sparsity_mask[i])
5596

5697
def get_trace_margin(self):
5798
return self.margin
@@ -76,17 +117,24 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
76117
spikes["channel_index"] = peak_chan_ind
77118
spikes["amplitude"] = 1.0
78119

79-
waveforms = traces[spikes["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)]
80-
num_templates = len(self.templates_array)
81-
XA = self.templates_array.reshape(num_templates, -1)
82-
83120
# naively take the closest template
84121
for main_chan in np.unique(spikes["channel_index"]):
85122
(idx,) = np.nonzero(spikes["channel_index"] == main_chan)
86-
XB = waveforms[idx].reshape(len(idx), -1)
87-
dist = cdist(XA, XB, "euclidean")
88-
cluster_index = np.argmin(dist, 0)
89-
spikes["cluster_index"][idx] = cluster_index
123+
124+
unit_inds = self.lookup_tables["templates"][main_chan]
125+
templates = self.templates_array[unit_inds]
126+
num_templates = templates.shape[0]
127+
if num_templates > 0:
128+
waveforms = traces[spikes["sample_index"][idx][:, None] + np.arange(-self.nbefore, self.nafter)]
129+
chan_inds = self.lookup_tables["channels"][main_chan]
130+
XA = templates[:, :, chan_inds].reshape(num_templates, -1)
131+
XB = waveforms[:, :, chan_inds].reshape(len(idx), -1)
132+
133+
dist = cdist(XA, XB, "euclidean")
134+
cluster_index = np.argmin(dist, 0)
135+
spikes["cluster_index"][idx] = unit_inds[cluster_index]
136+
else:
137+
spikes["cluster_index"][idx] = -1 # no template for this channel
90138

91139
return spikes
92140

@@ -111,13 +159,14 @@ def __init__(
111159
recording,
112160
templates,
113161
svd_model,
114-
svd_radius_um=100,
115162
return_output=True,
116163
peak_sign="neg",
117164
exclude_sweep_ms=0.1,
118165
detect_threshold=5,
119166
noise_levels=None,
120-
radius_um=100.0,
167+
detection_radius_um=100.0,
168+
neighborhood_radius_um=50.0,
169+
sparsity_radius_um=100.0,
121170
):
122171

123172
NearestTemplatesPeeler.__init__(
@@ -129,7 +178,9 @@ def __init__(
129178
exclude_sweep_ms=exclude_sweep_ms,
130179
detect_threshold=detect_threshold,
131180
noise_levels=noise_levels,
132-
radius_um=radius_um,
181+
detection_radius_um=detection_radius_um,
182+
neighborhood_radius_um=neighborhood_radius_um,
183+
sparsity_radius_um=sparsity_radius_um,
133184
)
134185

135186
from spikeinterface.sortingcomponents.waveforms.waveform_utils import (
@@ -139,10 +190,6 @@ def __init__(
139190

140191
self.num_channels = self.recording.get_num_channels()
141192
self.svd_model = svd_model
142-
self.svd_radius_um = svd_radius_um
143-
channel_distance = get_channel_distances(recording)
144-
self.svd_neighbours_mask = channel_distance <= self.svd_radius_um
145-
146193
temporal_templates = to_temporal_representation(self.templates_array)
147194
projected_temporal_templates = self.svd_model.transform(temporal_templates)
148195
self.svd_templates = from_temporal_representation(projected_temporal_templates, self.num_channels)
@@ -175,20 +222,28 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
175222
spikes["channel_index"] = peak_chan_ind
176223
spikes["amplitude"] = 1.0
177224

178-
waveforms = traces[spikes["sample_index"][:, None] + np.arange(-self.nbefore, self.nafter)]
179-
num_templates = len(self.templates_array)
180-
181-
temporal_waveforms = to_temporal_representation(waveforms)
182-
projected_temporal_waveforms = self.svd_model.transform(temporal_waveforms)
183-
projected_waveforms = from_temporal_representation(projected_temporal_waveforms, self.num_channels)
184-
225+
# naively take the closest template
185226
for main_chan in np.unique(spikes["channel_index"]):
186227
(idx,) = np.nonzero(spikes["channel_index"] == main_chan)
187-
(chan_inds,) = np.nonzero(self.svd_neighbours_mask[main_chan])
188-
local_svds = projected_waveforms[idx][:, :, chan_inds]
189-
XA = local_svds.reshape(len(idx), -1)
190-
XB = self.svd_templates[:, :, chan_inds].reshape(num_templates, -1)
191-
distances = cdist(XA, XB, metric="euclidean")
192-
spikes["cluster_index"][idx] = np.argmin(distances, axis=1)
228+
229+
unit_inds = self.lookup_tables["templates"][main_chan]
230+
templates = self.svd_templates[unit_inds]
231+
num_templates = templates.shape[0]
232+
233+
if num_templates > 0:
234+
chan_inds = self.lookup_tables["channels"][main_chan]
235+
waveforms = traces[spikes["sample_index"][idx][:, None] + np.arange(-self.nbefore, self.nafter)]
236+
temporal_waveforms = to_temporal_representation(waveforms)
237+
projected_temporal_waveforms = self.svd_model.transform(temporal_waveforms)
238+
projected_waveforms = from_temporal_representation(projected_temporal_waveforms, self.num_channels)
239+
240+
XA = templates[:, :, chan_inds].reshape(num_templates, -1)
241+
XB = projected_waveforms[:, :, chan_inds].reshape(len(idx), -1)
242+
243+
dist = cdist(XA, XB, "euclidean")
244+
cluster_index = np.argmin(dist, 0)
245+
spikes["cluster_index"][idx] = unit_inds[cluster_index]
246+
else:
247+
spikes["cluster_index"][idx] = -1 # no template for this channel
193248

194249
return spikes

0 commit comments

Comments
 (0)