@@ -36,22 +36,63 @@ def __init__(
3636 exclude_sweep_ms = 0.1 ,
3737 detect_threshold = 5 ,
3838 noise_levels = None ,
39- radius_um = 100.0 ,
39+ detection_radius_um = 100.0 ,
40+ neighborhood_radius_um = 50.0 ,
41+ sparsity_radius_um = 100.0 ,
4042 ):
4143
4244 BaseTemplateMatching .__init__ (self , recording , templates , return_output = return_output )
4345
44- self .templates_array = self .templates .get_dense_templates ()
45-
4646 self .noise_levels = noise_levels
4747 self .abs_threholds = self .noise_levels * detect_threshold
4848 self .peak_sign = peak_sign
49- channel_distance = get_channel_distances (recording )
50- self .neighbours_mask = channel_distance <= radius_um
49+ self .channel_distance = get_channel_distances (recording )
50+ self .neighbours_mask = self .channel_distance <= detection_radius_um
51+
52+ num_templates = len (self .templates .unit_ids )
53+ num_channels = recording .get_num_channels ()
54+
55+ if neighborhood_radius_um is not None :
56+ from spikeinterface .core .template_tools import get_template_extremum_channel
57+
58+ best_channels = get_template_extremum_channel (self .templates , peak_sign = self .peak_sign , outputs = "index" )
59+ best_channels = np .array ([best_channels [i ] for i in templates .unit_ids ])
60+ channel_locations = recording .get_channel_locations ()
61+ template_distances = np .linalg .norm (
62+ channel_locations [:, None ] - channel_locations [best_channels ][np .newaxis , :], axis = 2
63+ )
64+ self .neighborhood_mask = template_distances <= neighborhood_radius_um
65+ else :
66+ self .neighborhood_mask = np .ones ((num_channels , num_templates ), dtype = bool )
67+
68+ if sparsity_radius_um is not None :
69+ if not templates .are_templates_sparse ():
70+ from spikeinterface .core .sparsity import compute_sparsity
71+
72+ sparsity = compute_sparsity (
73+ templates , method = "radius" , radius_um = sparsity_radius_um , peak_sign = self .peak_sign
74+ )
75+ else :
76+ sparsity = templates .sparsity
77+
78+ self .sparsity_mask = np .zeros ((num_channels , num_channels ), dtype = bool )
79+ for channel_index in np .arange (num_channels ):
80+ mask = self .neighborhood_mask [channel_index ]
81+ self .sparsity_mask [channel_index ] = np .sum (sparsity .mask [mask ], axis = 0 ) > 0
82+ else :
83+ self .sparsity_mask = np .ones ((num_channels , num_channels ), dtype = bool )
84+
85+ self .templates_array = self .templates .get_dense_templates ()
5186 self .exclude_sweep_size = int (exclude_sweep_ms * recording .get_sampling_frequency () / 1000.0 )
5287 self .nbefore = self .templates .nbefore
5388 self .nafter = self .templates .nafter
5489 self .margin = max (self .nbefore , self .nafter )
90+ self .lookup_tables = {}
91+ self .lookup_tables ["templates" ] = {}
92+ self .lookup_tables ["channels" ] = {}
93+ for i in range (num_channels ):
94+ self .lookup_tables ["templates" ][i ] = np .flatnonzero (self .neighborhood_mask [i ])
95+ self .lookup_tables ["channels" ][i ] = np .flatnonzero (self .sparsity_mask [i ])
5596
5697 def get_trace_margin (self ):
5798 return self .margin
@@ -76,17 +117,24 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
76117 spikes ["channel_index" ] = peak_chan_ind
77118 spikes ["amplitude" ] = 1.0
78119
79- waveforms = traces [spikes ["sample_index" ][:, None ] + np .arange (- self .nbefore , self .nafter )]
80- num_templates = len (self .templates_array )
81- XA = self .templates_array .reshape (num_templates , - 1 )
82-
83120 # naively take the closest template
84121 for main_chan in np .unique (spikes ["channel_index" ]):
85122 (idx ,) = np .nonzero (spikes ["channel_index" ] == main_chan )
86- XB = waveforms [idx ].reshape (len (idx ), - 1 )
87- dist = cdist (XA , XB , "euclidean" )
88- cluster_index = np .argmin (dist , 0 )
89- spikes ["cluster_index" ][idx ] = cluster_index
123+
124+ unit_inds = self .lookup_tables ["templates" ][main_chan ]
125+ templates = self .templates_array [unit_inds ]
126+ num_templates = templates .shape [0 ]
127+ if num_templates > 0 :
128+ waveforms = traces [spikes ["sample_index" ][idx ][:, None ] + np .arange (- self .nbefore , self .nafter )]
129+ chan_inds = self .lookup_tables ["channels" ][main_chan ]
130+ XA = templates [:, :, chan_inds ].reshape (num_templates , - 1 )
131+ XB = waveforms [:, :, chan_inds ].reshape (len (idx ), - 1 )
132+
133+ dist = cdist (XA , XB , "euclidean" )
134+ cluster_index = np .argmin (dist , 0 )
135+ spikes ["cluster_index" ][idx ] = unit_inds [cluster_index ]
136+ else :
137+ spikes ["cluster_index" ][idx ] = - 1 # no template for this channel
90138
91139 return spikes
92140
@@ -111,13 +159,14 @@ def __init__(
111159 recording ,
112160 templates ,
113161 svd_model ,
114- svd_radius_um = 100 ,
115162 return_output = True ,
116163 peak_sign = "neg" ,
117164 exclude_sweep_ms = 0.1 ,
118165 detect_threshold = 5 ,
119166 noise_levels = None ,
120- radius_um = 100.0 ,
167+ detection_radius_um = 100.0 ,
168+ neighborhood_radius_um = 50.0 ,
169+ sparsity_radius_um = 100.0 ,
121170 ):
122171
123172 NearestTemplatesPeeler .__init__ (
@@ -129,7 +178,9 @@ def __init__(
129178 exclude_sweep_ms = exclude_sweep_ms ,
130179 detect_threshold = detect_threshold ,
131180 noise_levels = noise_levels ,
132- radius_um = radius_um ,
181+ detection_radius_um = detection_radius_um ,
182+ neighborhood_radius_um = neighborhood_radius_um ,
183+ sparsity_radius_um = sparsity_radius_um ,
133184 )
134185
135186 from spikeinterface .sortingcomponents .waveforms .waveform_utils import (
@@ -139,10 +190,6 @@ def __init__(
139190
140191 self .num_channels = self .recording .get_num_channels ()
141192 self .svd_model = svd_model
142- self .svd_radius_um = svd_radius_um
143- channel_distance = get_channel_distances (recording )
144- self .svd_neighbours_mask = channel_distance <= self .svd_radius_um
145-
146193 temporal_templates = to_temporal_representation (self .templates_array )
147194 projected_temporal_templates = self .svd_model .transform (temporal_templates )
148195 self .svd_templates = from_temporal_representation (projected_temporal_templates , self .num_channels )
@@ -175,20 +222,28 @@ def compute_matching(self, traces, start_frame, end_frame, segment_index):
175222 spikes ["channel_index" ] = peak_chan_ind
176223 spikes ["amplitude" ] = 1.0
177224
178- waveforms = traces [spikes ["sample_index" ][:, None ] + np .arange (- self .nbefore , self .nafter )]
179- num_templates = len (self .templates_array )
180-
181- temporal_waveforms = to_temporal_representation (waveforms )
182- projected_temporal_waveforms = self .svd_model .transform (temporal_waveforms )
183- projected_waveforms = from_temporal_representation (projected_temporal_waveforms , self .num_channels )
184-
225+ # naively take the closest template
185226 for main_chan in np .unique (spikes ["channel_index" ]):
186227 (idx ,) = np .nonzero (spikes ["channel_index" ] == main_chan )
187- (chan_inds ,) = np .nonzero (self .svd_neighbours_mask [main_chan ])
188- local_svds = projected_waveforms [idx ][:, :, chan_inds ]
189- XA = local_svds .reshape (len (idx ), - 1 )
190- XB = self .svd_templates [:, :, chan_inds ].reshape (num_templates , - 1 )
191- distances = cdist (XA , XB , metric = "euclidean" )
192- spikes ["cluster_index" ][idx ] = np .argmin (distances , axis = 1 )
228+
229+ unit_inds = self .lookup_tables ["templates" ][main_chan ]
230+ templates = self .svd_templates [unit_inds ]
231+ num_templates = templates .shape [0 ]
232+
233+ if num_templates > 0 :
234+ chan_inds = self .lookup_tables ["channels" ][main_chan ]
235+ waveforms = traces [spikes ["sample_index" ][idx ][:, None ] + np .arange (- self .nbefore , self .nafter )]
236+ temporal_waveforms = to_temporal_representation (waveforms )
237+ projected_temporal_waveforms = self .svd_model .transform (temporal_waveforms )
238+ projected_waveforms = from_temporal_representation (projected_temporal_waveforms , self .num_channels )
239+
240+ XA = templates [:, :, chan_inds ].reshape (num_templates , - 1 )
241+ XB = projected_waveforms [:, :, chan_inds ].reshape (len (idx ), - 1 )
242+
243+ dist = cdist (XA , XB , "euclidean" )
244+ cluster_index = np .argmin (dist , 0 )
245+ spikes ["cluster_index" ][idx ] = unit_inds [cluster_index ]
246+ else :
247+ spikes ["cluster_index" ][idx ] = - 1 # no template for this channel
193248
194249 return spikes
0 commit comments