
Commit 5dfb788

A new sorter : Lupin sorter (#4192)
1 parent 7def9dc commit 5dfb788

File tree

11 files changed: +524 -143 lines changed


src/spikeinterface/core/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@
 )
 from .sorting_tools import (
     spike_vector_to_spike_trains,
+    spike_vector_to_indices,
     random_spikes_selection,
     apply_merges_to_sorting,
     apply_splits_to_sorting,
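For context, the newly exported spike_vector_to_indices groups a sorting's flat spike vector back into per-unit spike indices. A minimal sketch of how it might be called; the exact signature and return layout are assumptions inferred from its neighbours in sorting_tools, not confirmed by this diff:

    from spikeinterface.core import NumpySorting, spike_vector_to_indices
    import numpy as np

    # toy sorting: two units with known spike times (in samples), one segment
    sorting = NumpySorting.from_unit_dict(
        [{"u0": np.array([100, 500, 900]), "u1": np.array([250, 750])}],
        sampling_frequency=30_000.0,
    )

    # one structured array per segment, fields: sample_index / unit_index / segment_index
    spike_vector = sorting.to_spike_vector(concatenated=False)

    # assumed behavior: positions into each segment's spike vector, grouped per unit,
    # i.e. spike_indices[segment_index][unit_id] -> np.ndarray of indices
    spike_indices = spike_vector_to_indices(spike_vector, sorting.unit_ids)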
Lines changed: 357 additions & 0 deletions
@@ -0,0 +1,357 @@
from __future__ import annotations

from .si_based import ComponentsBasedSorter

from copy import deepcopy

from spikeinterface.core import (
    get_noise_levels,
    NumpySorting,
    estimate_templates_with_accumulator,
    Templates,
    compute_sparsity,
)

from spikeinterface.core.job_tools import fix_job_kwargs

from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten
from spikeinterface.core.basesorting import minimum_spike_dtype

from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing


import numpy as np


class LupinSorter(ComponentsBasedSorter):
    """
    Gentleman thief spike sorter.

    This sorter is composed of pieces of code and ideas stolen from everywhere: yass, tridesclous, spyking-circus, kilosort.
    It should be the best sorter we can build using spikeinterface.sortingcomponents.
    """

    sorter_name = "lupin"

    _default_params = {
        "apply_preprocessing": True,
        "apply_motion_correction": False,
        "motion_correction_preset": "dredge_fast",
        "clustering_ms_before": 0.3,
        "clustering_ms_after": 1.3,
        "whitening_radius_um": 100.0,
        "detection_radius_um": 50.0,
        "features_radius_um": 75.0,
        "template_radius_um": 100.0,
        "freq_min": 150.0,
        "freq_max": 7000.0,
        "cache_preprocessing_mode": "auto",
        "peak_sign": "neg",
        "detect_threshold": 5,
        "n_peaks_per_channel": 5000,
        "n_svd_components_per_channel": 5,
        "n_pca_features": 4,
        "clustering_recursive_depth": 3,
        "ms_before": 1.0,
        "ms_after": 2.5,
        "sparsity_threshold": 1.5,
        "template_min_snr": 2.5,
        "gather_mode": "memory",
        "job_kwargs": {},
        "seed": None,
        "save_array": True,
        "debug": False,
    }

    _params_description = {
        "apply_preprocessing": "Apply internal preprocessing or not",
        "apply_motion_correction": "Apply motion correction or not",
        "motion_correction_preset": "Motion correction preset",
        "clustering_ms_before": "Milliseconds before the spike peak for clustering",
        "clustering_ms_after": "Milliseconds after the spike peak for clustering",
        "whitening_radius_um": "Radius for local whitening",
        "detection_radius_um": "Radius for peak detection",
        "features_radius_um": "Radius for sparsity of the clustering features",
        "template_radius_um": "Radius for template sparsity",
        "freq_min": "Low frequency",
        "freq_max": "High frequency",
        "cache_preprocessing_mode": "How to cache the preprocessed recording",
        "peak_sign": "Sign of peaks neg/pos/both",
        "detect_threshold": "Threshold for peak detection",
        "n_peaks_per_channel": "Number of spikes per channel for clustering",
        "n_svd_components_per_channel": "Number of SVD components per channel for clustering",
        "n_pca_features": "Secondary PCA feature reduction before local isosplit",
        "clustering_recursive_depth": "Recursion depth of the clustering splits",
        "ms_before": "Milliseconds before the spike peak for template matching",
        "ms_after": "Milliseconds after the spike peak for template matching",
        "sparsity_threshold": "Threshold to sparsify templates before template matching",
        "template_min_snr": "Threshold to remove templates before template matching",
        "gather_mode": "How to accumulate spikes in matching: memory/npy",
        "job_kwargs": "The famous and fabulous job_kwargs",
        "seed": "Seed for the random number generator",
        "save_array": "Save intermediate arrays in the folder or not",
        "debug": "Save debug files",
    }

    handle_multi_segment = True

    @classmethod
    def get_sorter_version(cls):
        return "2025.11"

    @classmethod
    def _run_from_folder(cls, sorter_output_folder, params, verbose):

        from spikeinterface.sortingcomponents.tools import get_prototype_and_waveforms_from_recording
        from spikeinterface.sortingcomponents.matching import find_spikes_from_templates
        from spikeinterface.sortingcomponents.peak_detection import detect_peaks
        from spikeinterface.sortingcomponents.peak_selection import select_peaks
        from spikeinterface.sortingcomponents.clustering.main import find_clusters_from_peaks, clustering_methods
        from spikeinterface.sortingcomponents.tools import remove_empty_templates
        from spikeinterface.preprocessing import correct_motion
        from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording
        from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label

        job_kwargs = params["job_kwargs"].copy()
        job_kwargs = fix_job_kwargs(job_kwargs)
        job_kwargs["progress_bar"] = verbose

        seed = params["seed"]

        recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

        num_chans = recording_raw.get_num_channels()
        sampling_frequency = recording_raw.get_sampling_frequency()

        apply_cmr = num_chans >= 32

        # preprocessing
        if params["apply_preprocessing"]:
            if params["apply_motion_correction"]:
                rec_for_motion = bandpass_filter(
                    recording_raw, freq_min=300.0, freq_max=6000.0, ftype="bessel", dtype="float32"
                )
                if apply_cmr:
                    rec_for_motion = common_reference(rec_for_motion)
                if verbose:
                    print("Start correct_motion()")
                _, motion_info = correct_motion(
                    rec_for_motion,
                    folder=sorter_output_folder / "motion",
                    output_motion_info=True,
                    preset=params["motion_correction_preset"],
                )
                if verbose:
                    print("Done correct_motion()")

            recording = bandpass_filter(
                recording_raw,
                freq_min=params["freq_min"],
                freq_max=params["freq_max"],
                ftype="bessel",
                filter_order=2,
                margin_ms=20.0,
                dtype="float32",
            )

            if apply_cmr:
                recording = common_reference(recording)

            recording = whiten(
                recording,
                dtype="float32",
                mode="local",
                radius_um=params["whitening_radius_um"],
            )

            if params["apply_motion_correction"]:
                interpolate_motion_kwargs = dict(
                    border_mode="force_extrapolate",
                    spatial_interpolation_method="kriging",
                    sigma_um=20.0,
                    p=2,
                )

                recording = InterpolateMotionRecording(
                    recording,
                    motion_info["motion"],
                    **interpolate_motion_kwargs,
                )

            # Cache in memory or in a folder
            cache_folder = sorter_output_folder / "cache_preprocessing"
            recording, cache_info = cache_preprocessing(
                recording,
                mode=params["cache_preprocessing_mode"],
                folder=cache_folder,
                job_kwargs=job_kwargs,
            )

            noise_levels = get_noise_levels(recording, return_in_uV=False)
        else:
            recording = recording_raw
            noise_levels = get_noise_levels(recording, return_in_uV=False)
            cache_info = None

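        # Note (editorial comment, not in the original diff): detection below is a
        # two-step matched-filtering scheme. A prototype waveform is first estimated
        # from a sample of ~10_000 peaks, then used as the matched filter for the
        # full detection pass over the recording.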
        # detection
        ms_before = params["ms_before"]
        ms_after = params["ms_after"]
        prototype, few_waveforms, few_peaks = get_prototype_and_waveforms_from_recording(
            recording,
            n_peaks=10_000,
            ms_before=ms_before,
            ms_after=ms_after,
            seed=seed,
            noise_levels=noise_levels,
            job_kwargs=job_kwargs,
        )
        detection_params = dict(
            peak_sign=params["peak_sign"],
            detect_threshold=params["detect_threshold"],
            exclude_sweep_ms=1.5,
            radius_um=params["detection_radius_um"],
            prototype=prototype,
            ms_before=ms_before,
        )
        all_peaks = detect_peaks(
            recording, method="matched_filtering", method_kwargs=detection_params, job_kwargs=job_kwargs
        )

        if verbose:
            print(f"detect_peaks(): {len(all_peaks)} peaks found")

        # selection
        n_peaks = max(params["n_peaks_per_channel"] * num_chans, 20_000)
        peaks = select_peaks(all_peaks, method="uniform", n_peaks=n_peaks)
        if verbose:
            print(f"select_peaks(): {len(peaks)} peaks kept for clustering")

        # Clustering
        clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params)
        clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"]
        clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"]
        clustering_kwargs["peaks_svd"]["radius_um"] = params["features_radius_um"]
        clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"]
        clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"]
        clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"]

        if params["debug"]:
            clustering_kwargs["debug_folder"] = sorter_output_folder
        unit_ids, clustering_label, more_outs = find_clusters_from_peaks(
            recording,
            peaks,
            method="iterative-isosplit",
            method_kwargs=clustering_kwargs,
            extra_outputs=True,
            job_kwargs=job_kwargs,
        )

        mask = clustering_label >= 0
        kept_peaks = peaks[mask]
        kept_labels = clustering_label[mask]

        sorting_pre_peeler = NumpySorting.from_samples_and_labels(
            kept_peaks["sample_index"],
            kept_labels,
            sampling_frequency,
            unit_ids=unit_ids,
        )
        if verbose:
            print(f"find_clusters_from_peaks(): {unit_ids.size} clusters found")

        # pre-estimate the sparsity using the peak channels
        spike_vector = sorting_pre_peeler.to_spike_vector(concatenated=True)
        sparsity, unit_locations = compute_sparsity_from_peaks_and_label(
            kept_peaks, spike_vector["unit_index"], sorting_pre_peeler.unit_ids, recording, params["template_radius_um"]
        )

        # Templates are sparsified by radius around unit_location
        nbefore = int(ms_before * sampling_frequency / 1000.0)
        nafter = int(ms_after * sampling_frequency / 1000.0)
        templates_array = estimate_templates_with_accumulator(
            recording,
            sorting_pre_peeler.to_spike_vector(),
            sorting_pre_peeler.unit_ids,
            nbefore,
            nafter,
            return_in_uV=False,
            sparsity_mask=sparsity.mask,
            **job_kwargs,
        )
        templates = Templates(
            templates_array=templates_array,
            sampling_frequency=sampling_frequency,
            nbefore=nbefore,
            channel_ids=recording.channel_ids,
            unit_ids=sorting_pre_peeler.unit_ids,
            sparsity_mask=sparsity.mask,
            probe=recording.get_probe(),
            is_in_uV=False,
        )

        # this sparsifies further
        templates = clean_templates(
            templates,
            sparsify_threshold=params["sparsity_threshold"],
            noise_levels=noise_levels,
            min_snr=params["template_min_snr"],
            max_jitter_ms=None,
            remove_empty=True,
        )

        # Template matching
        gather_mode = params["gather_mode"]
        pipeline_kwargs = dict(gather_mode=gather_mode)
        if gather_mode == "npy":
            pipeline_kwargs["folder"] = sorter_output_folder / "matching"

        spikes = find_spikes_from_templates(
            recording,
            templates,
            method="wobble",
            method_kwargs={},
            pipeline_kwargs=pipeline_kwargs,
            job_kwargs=job_kwargs,
        )

        final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype)
        final_spikes["sample_index"] = spikes["sample_index"]
        final_spikes["unit_index"] = spikes["cluster_index"]
        final_spikes["segment_index"] = spikes["segment_index"]
        sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids)

        auto_merge = True
        analyzer_final = None
        if auto_merge:
            # TODO expose some of these parameters
            from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus

            analyzer_final = final_cleaning_circus(
                recording,
                sorting,
                templates,
                similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1},
                sparsity_overlap=0.5,
                censor_ms=3.0,
                max_distance_um=50,
                template_diff_thresh=np.arange(0.05, 0.4, 0.05),
                debug_folder=None,
                job_kwargs=job_kwargs,
            )
            sorting = NumpySorting.from_sorting(analyzer_final.sorting)

        if params["save_array"]:
            sorting_pre_peeler = sorting_pre_peeler.save(folder=sorter_output_folder / "sorting_pre_peeler")
            np.save(sorter_output_folder / "noise_levels.npy", noise_levels)
            np.save(sorter_output_folder / "all_peaks.npy", all_peaks)
            np.save(sorter_output_folder / "peaks.npy", peaks)
            np.save(sorter_output_folder / "clustering_label.npy", clustering_label)
            np.save(sorter_output_folder / "spikes.npy", spikes)
            templates.to_zarr(sorter_output_folder / "templates.zarr")
            if analyzer_final is not None:
                analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer")

        sorting = sorting.save(folder=sorter_output_folder / "sorting")

        del recording
        clean_cache_preprocessing(cache_info)

        return sorting
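Not part of the diff, but for orientation: assuming the commit also registers LupinSorter in the sorter list (presumably one of the other changed files not shown here), the sorter should be runnable through the standard run_sorter API, with any of the _default_params keys passed as overrides. A minimal sketch on a synthetic recording:

    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.sorters import run_sorter

    # toy recording for illustration only
    recording, gt_sorting = generate_ground_truth_recording(num_channels=32, durations=[60.0], seed=0)

    # "lupin" is the sorter_name defined by this class; extra keyword arguments
    # map onto the documented _default_params keys
    sorting = run_sorter(
        "lupin",
        recording,
        folder="lupin_output",
        verbose=True,
        detect_threshold=5,
        job_kwargs=dict(n_jobs=4, chunk_duration="1s"),
    )
    print(sorting)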
