Skip to content

Commit 5a58beb

Browse files
ygerpre-commit-ci[bot]samuelgarciaalejoe91
authored
Improvements for internal sorters (#4279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Samuel Garcia <[email protected]> Co-authored-by: Alessio Buccino <[email protected]>
1 parent d7cfc07 commit 5a58beb

File tree

4 files changed

+111
-62
lines changed

4 files changed

+111
-62
lines changed

src/spikeinterface/sorters/internal/lupin.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class LupinSorter(ComponentsBasedSorter):
3535

3636
_default_params = {
3737
"apply_preprocessing": True,
38+
"preprocessing_dict": None,
3839
"apply_motion_correction": False,
3940
"motion_correction_preset": "dredge_fast",
4041
"clustering_ms_before": 0.3,
@@ -67,7 +68,8 @@ class LupinSorter(ComponentsBasedSorter):
6768

6869
_params_description = {
6970
"apply_preprocessing": "Apply internal preprocessing or not",
70-
"apply_motion_correction": "Apply motion correction or not",
71+
"preprocessing_dict": "Inject customized preprocessing chain via a dict, instead of the internal one",
72+
"apply_motion_correction": "Apply motion correction or not (only used when apply_preprocessing=True)",
7173
"motion_correction_preset": "Motion correction preset",
7274
"clustering_ms_before": "Milliseconds before the spike peak for clustering",
7375
"clustering_ms_after": "Milliseconds after the spike peak for clustering",
@@ -111,6 +113,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
111113
from spikeinterface.preprocessing import correct_motion
112114
from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording
113115
from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label
116+
from spikeinterface.preprocessing import apply_preprocessing_pipeline
114117

115118
job_kwargs = params["job_kwargs"].copy()
116119
job_kwargs = fix_job_kwargs(job_kwargs)
@@ -156,36 +159,42 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
156159
# preprocessing
157160
if params["apply_preprocessing"]:
158161
if params["apply_motion_correction"]:
162+
159163
rec_for_motion = recording_raw
160-
if params["apply_preprocessing"]:
164+
if params["preprocessing_dict"] is None:
161165
rec_for_motion = bandpass_filter(
162166
rec_for_motion, freq_min=300.0, freq_max=6000.0, ftype="bessel", dtype="float32"
163167
)
164168
if apply_cmr:
165169
rec_for_motion = common_reference(rec_for_motion)
166-
if verbose:
167-
print("Start correct_motion()")
168-
_, motion_info = correct_motion(
169-
rec_for_motion,
170-
folder=sorter_output_folder / "motion",
171-
output_motion_info=True,
172-
preset=params["motion_correction_preset"],
173-
)
174-
if verbose:
175-
print("Done correct_motion()")
176-
177-
recording = bandpass_filter(
178-
recording_raw,
179-
freq_min=params["freq_min"],
180-
freq_max=params["freq_max"],
181-
ftype="bessel",
182-
filter_order=2,
183-
margin_ms=20.0,
184-
dtype="float32",
185-
)
170+
else:
171+
rec_for_motion = apply_preprocessing_pipeline(rec_for_motion, params["preprocessing_dict"])
172+
173+
if verbose:
174+
print("Start correct_motion()")
175+
_, motion_info = correct_motion(
176+
rec_for_motion,
177+
folder=sorter_output_folder / "motion",
178+
output_motion_info=True,
179+
preset=params["motion_correction_preset"],
180+
)
181+
if verbose:
182+
print("Done correct_motion()")
183+
184+
if params["preprocessing_dict"] is None:
185+
recording = bandpass_filter(
186+
recording_raw,
187+
freq_min=params["freq_min"],
188+
freq_max=params["freq_max"],
189+
ftype="bessel",
190+
filter_order=2,
191+
dtype="float32",
192+
)
186193

187-
if apply_cmr:
188-
recording = common_reference(recording)
194+
if apply_cmr:
195+
recording = common_reference(recording)
196+
else:
197+
recording = apply_preprocessing_pipeline(recording, params["preprocessing_dict"])
189198

190199
recording = whiten(
191200
recording,
@@ -219,7 +228,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
219228

220229
noise_levels = get_noise_levels(recording, return_in_uV=False)
221230
else:
222-
recording = recording_raw
231+
recording = recording_raw.astype("float32")
223232
noise_levels = get_noise_levels(recording, return_in_uV=False)
224233
cache_info = None
225234

@@ -366,6 +375,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
366375
recording,
367376
sorting,
368377
templates,
378+
amplitude_scalings=spikes["amplitude"],
379+
noise_levels=noise_levels,
369380
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1},
370381
sparsity_overlap=0.5,
371382
censor_ms=3.0,

src/spikeinterface/sorters/internal/spyking_circus2.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
2424

2525
_default_params = {
2626
"general": {"ms_before": 0.5, "ms_after": 1.5, "radius_um": 100.0},
27-
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2, "margin_ms": 20},
27+
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
2828
"whitening": {"mode": "local", "regularize": False},
2929
"detection": {
3030
"method": "matched_filtering",
@@ -68,6 +68,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
6868
"motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)",
6969
"apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\
7070
median reference + whitening",
71+
"apply_whitening": "Boolean to specify whether circus 2 should whiten the recording or not",
7172
"apply_motion_correction": "Boolean to specify whether circus 2 should apply motion correction to the recording or not",
7273
"matched_filtering": "Boolean to specify whether circus 2 should detect peaks via matched filtering (slightly slower)",
7374
"cache_preprocessing": "How to cache the preprocessed recording. Mode can be memory, file, zarr, with extra arguments. In case of memory (default), \
@@ -145,7 +146,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
145146
print("Skipping preprocessing (whitening only)")
146147
else:
147148
print("Skipping preprocessing (no whitening)")
148-
recording_f = recording
149+
recording_f = recording.astype("float32")
149150
recording_f.annotate(is_filtered=True)
150151

151152
if apply_whitening:
@@ -189,9 +190,10 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
189190
)
190191

191192
if recording_w.check_serializability("json"):
192-
recording_w.dump(sorter_output_folder / "preprocessed_recording.json", relative_to=None)
193+
recording_dump_file = sorter_output_folder / "preprocessed_recording.json"
193194
elif recording_w.check_serializability("pickle"):
194-
recording_w.dump(sorter_output_folder / "preprocessed_recording.pickle", relative_to=None)
195+
recording_dump_file = sorter_output_folder / "preprocessed_recording.pickle"
196+
recording_w.dump(recording_dump_file, relative_to=None)
195197

196198
recording_w, cache_info = cache_preprocessing(
197199
recording_w, job_kwargs=job_kwargs, **params["cache_preprocessing"]
@@ -468,10 +470,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
468470
# np.save(fitting_folder / "amplitudes", guessed_amplitudes)
469471

470472
if sorting.get_non_empty_unit_ids().size > 0:
473+
from spikeinterface.core import load
474+
475+
recording_ww = load(recording_dump_file)
476+
471477
final_analyzer = final_cleaning_circus(
472-
recording_w,
478+
recording_ww,
473479
sorting,
474480
templates,
481+
amplitude_scalings=spikes["amplitude"],
475482
noise_levels=noise_levels,
476483
job_kwargs=job_kwargs,
477484
**merging_params,
@@ -495,6 +502,7 @@ def final_cleaning_circus(
495502
recording,
496503
sorting,
497504
templates,
505+
amplitude_scalings=None,
498506
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1},
499507
sparsity_overlap=0.5,
500508
censor_ms=3.0,
@@ -509,7 +517,9 @@ def final_cleaning_circus(
509517
from spikeinterface.curation.auto_merge import auto_merge_units
510518

511519
# First we compute the needed extensions
512-
analyzer = create_sorting_analyzer_with_existing_templates(sorting, recording, templates, noise_levels=noise_levels)
520+
analyzer = create_sorting_analyzer_with_existing_templates(
521+
sorting, recording, templates, noise_levels=noise_levels, amplitude_scalings=amplitude_scalings
522+
)
513523
analyzer.compute("unit_locations", method="center_of_mass", **job_kwargs)
514524
analyzer.compute("template_similarity", **similarity_kwargs)
515525

src/spikeinterface/sorters/internal/tridesclous2.py

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
3131

3232
_default_params = {
3333
"apply_preprocessing": True,
34+
"preprocessing_dict": None,
3435
"apply_motion_correction": False,
3536
"motion_correction_preset": "dredge_fast",
3637
"clustering_ms_before": 0.5,
@@ -62,7 +63,8 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
6263

6364
_params_description = {
6465
"apply_preprocessing": "Apply internal preprocessing or not",
65-
"apply_motion_correction": "Apply motion correction or not",
66+
"preprocessing_dict": "Inject customized preprocessing chain via a dict, instead of the internal one",
67+
"apply_motion_correction": "Apply motion correction or not (only used when apply_preprocessing=True)",
6668
"motion_correction_preset": "Motion correction preset",
6769
"clustering_ms_before": "Milliseconds before the spike peak for clustering",
6870
"clustering_ms_after": "Milliseconds after the spike peak for clustering",
@@ -104,6 +106,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
104106
from spikeinterface.preprocessing import correct_motion
105107
from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording
106108
from spikeinterface.sortingcomponents.tools import clean_templates, compute_sparsity_from_peaks_and_label
109+
from spikeinterface.preprocessing import apply_preprocessing_pipeline
107110

108111
job_kwargs = params["job_kwargs"].copy()
109112
job_kwargs = fix_job_kwargs(job_kwargs)
@@ -121,38 +124,42 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
121124
# preprocessing
122125
if params["apply_preprocessing"]:
123126
if params["apply_motion_correction"]:
127+
124128
rec_for_motion = recording_raw
125-
if params["apply_preprocessing"]:
129+
if params["preprocessing_dict"] is None:
126130
rec_for_motion = bandpass_filter(
127131
rec_for_motion, freq_min=300.0, freq_max=6000.0, ftype="bessel", dtype="float32"
128132
)
129133
if apply_cmr:
130134
rec_for_motion = common_reference(rec_for_motion)
131-
if verbose:
132-
print("Start correct_motion()")
133-
_, motion_info = correct_motion(
134-
rec_for_motion,
135-
folder=sorter_output_folder / "motion",
136-
output_motion_info=True,
137-
preset=params["motion_correction_preset"],
138-
# **params["motion_correction"],
139-
)
140-
if verbose:
141-
print("Done correct_motion()")
142-
143-
# recording = bandpass_filter(recording_raw, **params["filtering"], margin_ms=20.0, dtype="float32")
144-
recording = bandpass_filter(
145-
recording_raw,
146-
freq_min=params["freq_min"],
147-
freq_max=params["freq_max"],
148-
ftype="bessel",
149-
filter_order=2,
150-
margin_ms=20.0,
151-
dtype="float32",
152-
)
135+
else:
136+
rec_for_motion = apply_preprocessing_pipeline(rec_for_motion, params["preprocessing_dict"])
137+
138+
if verbose:
139+
print("Start correct_motion()")
140+
_, motion_info = correct_motion(
141+
rec_for_motion,
142+
folder=sorter_output_folder / "motion",
143+
output_motion_info=True,
144+
preset=params["motion_correction_preset"],
145+
)
146+
if verbose:
147+
print("Done correct_motion()")
148+
149+
if params["preprocessing_dict"] is None:
150+
recording = bandpass_filter(
151+
recording_raw,
152+
freq_min=params["freq_min"],
153+
freq_max=params["freq_max"],
154+
ftype="bessel",
155+
filter_order=2,
156+
dtype="float32",
157+
)
153158

154-
if apply_cmr:
155-
recording = common_reference(recording)
159+
if apply_cmr:
160+
recording = common_reference(recording)
161+
else:
162+
recording = apply_preprocessing_pipeline(recording, params["preprocessing_dict"])
156163

157164
if params["apply_motion_correction"]:
158165
interpolate_motion_kwargs = dict(
@@ -183,7 +190,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
183190

184191
noise_levels = np.ones(num_chans, dtype="float32")
185192
else:
186-
recording = recording_raw
193+
recording = recording_raw.astype("float32")
187194
noise_levels = get_noise_levels(recording, return_in_uV=False)
188195
cache_info = None
189196

@@ -332,6 +339,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
332339
recording_for_peeler,
333340
sorting,
334341
templates,
342+
amplitude_scalings=spikes["amplitude"],
343+
noise_levels=noise_levels,
335344
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1},
336345
sparsity_overlap=0.5,
337346
censor_ms=3.0,

src/spikeinterface/sortingcomponents/tools.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@
1212

1313
from spikeinterface.core.sparsity import ChannelSparsity
1414
from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer
15-
from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs
15+
from spikeinterface.core.job_tools import fix_job_kwargs
1616
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
1717
from spikeinterface.core.sparsity import ChannelSparsity
1818
from spikeinterface.core.sparsity import compute_sparsity
19-
from spikeinterface.core.analyzer_extension_core import ComputeTemplates, ComputeNoiseLevels
20-
from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift, get_template_extremum_channel
19+
from spikeinterface.core.template_tools import get_template_extremum_channel_peak_shift
2120
from spikeinterface.core.recording_tools import get_noise_levels
2221
from spikeinterface.core.sorting_tools import get_numba_vector_to_list_of_spiketrain
2322

@@ -480,7 +479,7 @@ def remove_empty_templates(templates):
480479

481480

482481
def create_sorting_analyzer_with_existing_templates(
483-
sorting, recording, templates, remove_empty=True, noise_levels=None
482+
sorting, recording, templates, remove_empty=True, noise_levels=None, amplitude_scalings=None
484483
):
485484
sparsity = templates.sparsity
486485
templates_array = templates.get_dense_templates().copy()
@@ -495,6 +494,8 @@ def create_sorting_analyzer_with_existing_templates(
495494
else:
496495
non_empty_sorting = sorting
497496

497+
from spikeinterface.core.analyzer_extension_core import ComputeTemplates
498+
498499
sa = create_sorting_analyzer(non_empty_sorting, recording, format="memory", sparsity=sparsity)
499500
sa.compute("random_spikes")
500501
sa.extensions["templates"] = ComputeTemplates(sa)
@@ -509,12 +510,30 @@ def create_sorting_analyzer_with_existing_templates(
509510
sa.extensions["templates"].run_info["runtime_s"] = 0
510511

511512
if noise_levels is not None:
513+
from spikeinterface.core.analyzer_extension_core import ComputeNoiseLevels
514+
512515
sa.extensions["noise_levels"] = ComputeNoiseLevels(sa)
513516
sa.extensions["noise_levels"].params = {}
514517
sa.extensions["noise_levels"].data["noise_levels"] = noise_levels
515518
sa.extensions["noise_levels"].run_info["run_completed"] = True
516519
sa.extensions["noise_levels"].run_info["runtime_s"] = 0
517520

521+
if amplitude_scalings is not None:
522+
from spikeinterface.postprocessing.amplitude_scalings import ComputeAmplitudeScalings
523+
524+
sa.extensions["amplitude_scalings"] = ComputeAmplitudeScalings(sa)
525+
sa.extensions["amplitude_scalings"].params = dict(
526+
sparsity=None,
527+
max_dense_channels=16,
528+
ms_before=templates.ms_before,
529+
ms_after=templates.ms_after,
530+
handle_collisions=False,
531+
delta_collision_ms=2,
532+
)
533+
sa.extensions["amplitude_scalings"].data["amplitude_scalings"] = amplitude_scalings
534+
sa.extensions["amplitude_scalings"].run_info["run_completed"] = True
535+
sa.extensions["amplitude_scalings"].run_info["runtime_s"] = 0
536+
518537
return sa
519538

520539

0 commit comments

Comments
 (0)