7 changes: 4 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -604,8 +604,9 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
arr = np.std(wfs, axis=0)
elif operator == "median":
arr = np.median(wfs, axis=0)
elif "percentile" in operator:
_, percentile = operator.splot("_")
# there was a spelling error "pencentile" in old versions
elif "percentile" in operator or "pencentile" in operator:
_, percentile = operator.split("_")
arr = np.percentile(wfs, float(percentile), axis=0)
new_array[split_unit_index, ...] = arr
else:
@@ -824,7 +825,7 @@ class BaseMetric:
metric_descriptions = {} # descriptions of each metric column
needs_recording = False # whether the metric needs recording
needs_tmp_data = (
False # whether the metric needs temporary data comoputed with _prepare_data at the MetricExtension level
False # whether the metric needs temporary data computed with _prepare_data at the MetricExtension level
)
needs_job_kwargs = False
depend_on = [] # extensions the metric depends on
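For reference, the corrected branch parses the percentile value out of the operator string. A minimal standalone sketch of that logic — the `reduce_waveforms` helper and the waveform shape are illustrative, not part of the extension API:

```python
import numpy as np

def reduce_waveforms(wfs, operator):
    """Reduce a (num_spikes, num_samples, num_channels) array along axis 0."""
    if operator == "std":
        return np.std(wfs, axis=0)
    elif operator == "median":
        return np.median(wfs, axis=0)
    elif "percentile" in operator or "pencentile" in operator:
        # "percentile_90" -> ("percentile", "90"); the misspelled "pencentile"
        # written by old versions parses the same way
        _, percentile = operator.split("_")
        return np.percentile(wfs, float(percentile), axis=0)
    raise ValueError(f"Unknown operator: {operator}")

wfs = np.random.default_rng(0).normal(size=(100, 50, 4))
arr = reduce_waveforms(wfs, "percentile_90")
assert arr.shape == (50, 4)
```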
1 change: 1 addition & 0 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -2238,6 +2238,7 @@ class AnalyzerExtension:
* _run()
* _select_extension_data()
* _merge_extension_data()
* _split_extension_data()
* _get_data()

The subclass must also set an `extension_name` class attribute which is not None by default.
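The docstring above now lists `_split_extension_data()` among the methods a subclass must implement. A skeleton of a conforming subclass, as a sketch: only the method names and the `_split_extension_data` signature come from this diff; the other signatures are assumptions.

```python
from spikeinterface.core.sortinganalyzer import AnalyzerExtension

class MyExtension(AnalyzerExtension):
    extension_name = "my_extension"  # must be overridden (None by default)

    def _run(self, **kwargs):
        # compute the extension data and store it
        ...

    def _select_extension_data(self, unit_ids):
        # return the extension data restricted to the given units
        ...

    def _merge_extension_data(self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, **kwargs):
        # combine the data of merged units (signature assumed)
        ...

    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, **kwargs):
        # new hook: redistribute the data of a unit that was split
        ...

    def _get_data(self):
        # return the stored data in its user-facing form
        ...
```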
32 changes: 26 additions & 6 deletions src/spikeinterface/curation/curation_format.py
@@ -6,7 +6,7 @@
from itertools import chain

from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting, apply_splits_to_sorting
from spikeinterface.curation.curation_model import CurationModel
from spikeinterface.curation.curation_model import CurationModel, SequentialCuration


def validate_curation_dict(curation_dict: dict):
@@ -138,7 +138,7 @@ def apply_curation_labels(

def apply_curation(
sorting_or_analyzer: BaseSorting | SortingAnalyzer,
curation_dict_or_model: dict | CurationModel,
curation_dict_or_model: dict | list | CurationModel | SequentialCuration,
censor_ms: float | None = None,
new_id_strategy: str = "append",
merging_mode: str = "soft",
@@ -164,7 +164,7 @@
----------
sorting_or_analyzer : Sorting | SortingAnalyzer
The Sorting or SortingAnalyzer object to apply merges.
curation_dict : dict or CurationModel
curation_dict_or_model : dict | list | CurationModel | SequentialCuration
The curation dict or model.
censor_ms : float | None, default: None
When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of
@@ -199,14 +199,32 @@
sorting_or_analyzer, (BaseSorting, SortingAnalyzer)
), f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}"
assert isinstance(
curation_dict_or_model, (dict, CurationModel)
), f"`curation_dict_or_model` must be a dict or a CurationModel, not an object of type {type(curation_dict_or_model)}"
curation_dict_or_model, (dict, list, CurationModel, SequentialCuration)
), f"`curation_dict_or_model` must be a dict, CurationModel or a SequentialCuration not an object of type {type(curation_dict_or_model)}"
if isinstance(curation_dict_or_model, dict):
curation_model = CurationModel(**curation_dict_or_model)
elif isinstance(curation_dict_or_model, list):
curation_model = SequentialCuration(curation_steps=curation_dict_or_model)
else:
curation_model = curation_dict_or_model.model_copy(deep=True)

if not np.array_equal(np.asarray(curation_model.unit_ids), sorting_or_analyzer.unit_ids):
if isinstance(curation_model, SequentialCuration):
for c, single_curation_model in enumerate(curation_model.curation_steps):
if verbose:
print(f"Applying curation step: {c + 1} / {len(curation_model.curation_steps)}")
sorting_or_analyzer = apply_curation(
sorting_or_analyzer,
single_curation_model,
censor_ms=censor_ms,
new_id_strategy=new_id_strategy,
merging_mode=merging_mode,
sparsity_overlap=sparsity_overlap,
raise_error_if_overlap_fails=raise_error_if_overlap_fails,
verbose=verbose,
job_kwargs=job_kwargs,
)
return sorting_or_analyzer

if set(curation_model.unit_ids) != set(sorting_or_analyzer.unit_ids):
raise ValueError("unit_ids from the curation_dict do not match those of the Sorting or SortingAnalyzer")

# 1. Apply labels
@@ -228,13 +246,15 @@
curated_sorting_or_analyzer, _, _ = apply_merges_to_sorting(
curated_sorting_or_analyzer,
merge_unit_groups=merge_unit_groups,
new_unit_ids=merge_new_unit_ids,
censor_ms=censor_ms,
new_id_strategy=new_id_strategy,
return_extra=True,
)
else:
curated_sorting_or_analyzer, _ = curated_sorting_or_analyzer.merge_units(
merge_unit_groups=merge_unit_groups,
new_unit_ids=merge_new_unit_ids,
censor_ms=censor_ms,
merging_mode=merging_mode,
sparsity_overlap=sparsity_overlap,
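A usage sketch of the new list input, mirroring the test added below: each step's `unit_ids` must match the ids produced by the previous step, so all but the last step declare explicit new ids. The unit ids and steps here are illustrative.

```python
from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording
from spikeinterface.curation import apply_curation

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=3, seed=0)
sorting = sorting.rename_units([1, 2, 3])
analyzer = create_sorting_analyzer(sorting, recording, sparse=False)

steps = [
    # step 1: merge units 1 and 2 into an explicitly named unit 12
    {"format_version": "2", "unit_ids": [1, 2, 3], "merges": [{"unit_ids": [1, 2], "new_unit_id": 12}]},
    # step 2: operates on the ids produced by step 1
    {"format_version": "2", "unit_ids": [3, 12], "removed": [3]},
]

curated = apply_curation(analyzer, steps, verbose=True)
assert set(curated.unit_ids) == {12}
```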
83 changes: 82 additions & 1 deletion src/spikeinterface/curation/curation_model.py
@@ -190,7 +190,7 @@ def check_merges(cls, values):

# Check new unit id not already used
if merge.new_unit_id is not None:
if merge.new_unit_id in unit_ids:
if merge.new_unit_id in unit_ids and merge.new_unit_id not in merge.unit_ids:
raise ValueError(f"New unit ID {merge.new_unit_id} is already in the unit list")

values["merges"] = merges
@@ -366,6 +366,47 @@ def convert_old_format(cls, values):
values["removed"] = list(removed_units)
return values

def get_final_ids_from_new_unit_ids(self) -> list:
"""
Returns the final unit ids of the `curation_model`, when new unit ids are
given for each curation choice. Raises an error if new unit ids are missing
for any curation choice.

Returns
-------
final_ids : list
The ids of the sorting/analyzer after curation takes place
"""

initial_ids = self.unit_ids

ids_to_add = []
ids_to_remove = list(self.removed)

for split in self.splits:
if split.new_unit_ids is None:
raise ValueError(
f"The `new_unit_ids` for the split of unit {split.unit_id} is `None`. These must be given."
)
ids_to_remove.append(split.unit_id)
ids_to_add = ids_to_add + split.new_unit_ids
for merge in self.merges:
if merge.new_unit_id is None:
raise ValueError(
f"The `new_unit_id` for the merge of units {merge.unit_ids} is `None`. This must be given."
)
ids_to_remove = ids_to_remove + merge.unit_ids
ids_to_add.append(merge.new_unit_id)

final_ids = set(initial_ids).difference(ids_to_remove).union(ids_to_add)
return list(final_ids)

@model_validator(mode="before")
def validate_fields(cls, values):
values = dict(values)
@@ -430,3 +471,43 @@ def validate_curation_dict(self):
)

return self


class SequentialCuration(BaseModel):
"""
A Pydantic model defining a sequence of curation steps. In a sequential curation,
every step except the final one must define its new unit ids explicitly,
and those ids must match the unit ids of the following step.
"""

curation_steps: List[CurationModel] = Field(description="List of curation steps applied sequentially")

@model_validator(mode="after")
def validate_sequential_curation(self):

for curation in self.curation_steps[:-1]:
for merge in curation.merges:
if merge.new_unit_id is None:
raise ValueError(
"In a sequential curation, all curation decisions must have explicit `new_unit_id`s defined."
)
for split in curation.splits:
if split.new_unit_ids is None:
raise ValueError(
"In a sequential curation, all curation decisions must have explicit `new_unit_id`s defined."
)

for curation_index in range(len(self.curation_steps) - 1):
curation_1 = self.curation_steps[curation_index]
curation_2 = self.curation_steps[curation_index + 1]

previous_model_final_ids = curation_1.get_final_ids_from_new_unit_ids()
next_model_initial_ids = curation_2.unit_ids

if set(previous_model_final_ids) != set(next_model_initial_ids):
raise ValueError(
f"The initial unit_ids of curation {curation_index+1} do not match the final unit_ids of curation {curation_index}."
)

return self
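The validator above chains steps by comparing each step's `get_final_ids_from_new_unit_ids()` output with the next step's `unit_ids`. A small sketch of what the helper computes for a single model (values illustrative):

```python
from spikeinterface.curation.curation_model import CurationModel

model = CurationModel(
    format_version="2",
    unit_ids=[1, 2, 3, 4],
    merges=[{"unit_ids": [1, 2], "new_unit_id": 12}],
    removed=[4],
)

# {1, 2} collapse into 12 and 4 is dropped, so the final set is {3, 12}
assert set(model.get_final_ids_from_new_unit_ids()) == {3, 12}
```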
41 changes: 41 additions & 0 deletions src/spikeinterface/curation/tests/test_curation_format.py
@@ -171,6 +171,27 @@
# This is a failure because unit 99 is not in the initial list
unknown_removed_unit = {**curation_ids_int, "removed": [31, 42, 99]}

# Sequential curation test data
sequential_curation = [
{
"format_version": "2",
"unit_ids": [1, 2, 3, 4, 5],
"merges": [{"unit_ids": [3, 4], "new_unit_id": 34}],
},
{
"format_version": "2",
"unit_ids": [1, 2, 34, 5],
"splits": [{"unit_id": 34, "mode": "indices", "indices": [[0, 1, 2, 3]], "new_unit_ids": [340, 341]}],
},
{
"format_version": "2",
"unit_ids": [1, 2, 340, 341, 5],
"removed": [2, 5],
"merges": [{"unit_ids": [1, 340], "new_unit_id": 100}],
"splits": [{"unit_id": 341, "mode": "indices", "indices": [[0, 1, 2]], "new_unit_ids": [3410, 3411]}],
},
]


def test_curation_format_validation():
# Test basic formats
@@ -412,6 +433,26 @@ def test_apply_curation_splits_with_mask():
assert spike_counts[45] == num_spikes - 2 * (num_spikes // 3) # Remainder


def test_apply_sequential_curation():
recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=2205)
sorting = sorting.rename_units([1, 2, 3, 4, 5])
analyzer = create_sorting_analyzer(sorting, recording, sparse=False)

# sequential curation steps:
# 1. merge 3 and 4 -> 34
# 2. split 34 -> 340, 341
# 3. remove 2, 5; merge 1 and 340 -> 100; split 341 -> 3410, 3411
analyzer_curated = apply_curation(analyzer, sequential_curation, verbose=True)
# initial -1(merge) +1(split) -2(remove) -1(merge) +1(split)
num_final_units = analyzer.get_num_units() - 1 + 1 - 2 - 1 + 1
assert analyzer_curated.get_num_units() == num_final_units

# check final unit ids
final_unit_ids = analyzer_curated.sorting.unit_ids
expected_final_unit_ids = [100, 3410, 3411]
assert set(final_unit_ids) == set(expected_final_unit_ids)


if __name__ == "__main__":
test_curation_format_validation()
test_to_from_json()
32 changes: 31 additions & 1 deletion src/spikeinterface/curation/tests/test_curation_model.py
@@ -3,7 +3,7 @@
from pydantic import ValidationError
import numpy as np
from copy import deepcopy

from spikeinterface.curation.curation_model import CurationModel, LabelDefinition
from spikeinterface.curation.curation_model import CurationModel, SequentialCuration, LabelDefinition


# Test data for format version
@@ -282,3 +282,33 @@ def test_complete_model():
assert len(model.merges) == 1
assert len(model.splits) == 1
assert len(model.removed) == 1


def test_sequential_curation():
sequential_curation_steps_valid = [
{"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": [{"unit_ids": [1, 2], "new_unit_id": 22}]},
{
"format_version": "2",
"unit_ids": [3, 4, 22],
"splits": [
{"unit_id": 22, "mode": "indices", "indices": [[0, 1, 2], [3, 4, 5]], "new_unit_ids": [222, 223]}
],
},
{"format_version": "2", "unit_ids": [3, 4, 222, 223], "removed": [223]},
]

# this is valid
SequentialCuration(curation_steps=sequential_curation_steps_valid)

# deep copy: the nested dicts are mutated below, and a shallow copy would corrupt the valid steps
sequential_curation_steps_no_ids = deepcopy(sequential_curation_steps_valid)
# remove new_unit_id in merge step
sequential_curation_steps_no_ids[0]["merges"][0]["new_unit_id"] = None

with pytest.raises(ValidationError):
SequentialCuration(curation_steps=sequential_curation_steps_no_ids)

sequential_curation_steps_invalid = deepcopy(sequential_curation_steps_valid)
# invalid unit_ids in last step
sequential_curation_steps_invalid[2]["unit_ids"] = [3, 4, 222, 224] # 224 should be 223
with pytest.raises(ValidationError):
SequentialCuration(curation_steps=sequential_curation_steps_invalid)
8 changes: 8 additions & 0 deletions src/spikeinterface/metrics/quality/quality_metrics.py
@@ -64,6 +64,14 @@ def _handle_backward_compatibility_on_load(self):
if "mahalanobis" not in self.params["metric_names"]:
self.params["metric_names"].append("mahalanobis")

if "amplitude_cutoff" in self.params["metric_names"]:
if "peak_sign" in self.params["metric_params"]["amplitude_cutoff"]:
del self.params["metric_params"]["amplitude_cutoff"]["peak_sign"]

if "amplitude_median" in self.params["metric_names"]:
if "peak_sign" in self.params["metric_params"]["amplitude_median"]:
del self.params["metric_params"]["amplitude_median"]["peak_sign"]

def _set_params(
self,
metric_names: list[str] | None = None,
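The hook above removes the obsolete `peak_sign` entry from parameters saved by older versions. The same pattern on a standalone params dict (keys and values illustrative):

```python
# params dict as it might look when loaded from an older saved analyzer
params = {
    "metric_names": ["amplitude_cutoff", "amplitude_median"],
    "metric_params": {
        "amplitude_cutoff": {"peak_sign": "neg", "num_histogram_bins": 500},
        "amplitude_median": {"peak_sign": "neg"},
    },
}

# drop the key that the current metrics no longer accept
for metric in ("amplitude_cutoff", "amplitude_median"):
    if metric in params["metric_names"]:
        params["metric_params"][metric].pop("peak_sign", None)

assert all("peak_sign" not in p for p in params["metric_params"].values())
```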
29 changes: 29 additions & 0 deletions src/spikeinterface/metrics/template/template_metrics.py
@@ -112,19 +112,48 @@ def _handle_backward_compatibility_on_load(self):
self.params["metric_names"].remove("num_positive_peaks")
if "number_of_peaks" not in self.params["metric_names"]:
self.params["metric_names"].append("number_of_peaks")
self.params["metric_params"]["number_of_peaks"] = self.params["metric_params"]["num_positive_peaks"]
if "num_negative_peaks" in self.params["metric_names"]:
self.params["metric_names"].remove("num_negative_peaks")
if "number_of_peaks" not in self.params["metric_names"]:
self.params["metric_names"].append("number_of_peaks")

# velocity_above/velocity_below merged into velocity_fits
if "velocity_above" in self.params["metric_names"]:
self.params["metric_names"].remove("velocity_above")
if "velocity_fits" not in self.params["metric_names"]:
self.params["metric_names"].append("velocity_fits")

# also update parameters
if "velocity_above" in self.params["metric_params"]:
self.params["metric_params"]["velocity_fits"] = self.params["metric_params"]["velocity_above"]
self.params["metric_params"]["velocity_fits"]["min_channels"] = self.params["metric_params"][
"velocity_above"
]["min_channels_for_velocity"]
self.params["metric_params"]["velocity_fits"]["min_r2"] = self.params["metric_params"][
"velocity_above"
]["min_r2_velocity"]
del self.params["metric_params"]["velocity_above"]

if "velocity_below" in self.params["metric_names"]:
self.params["metric_names"].remove("velocity_below")
if "velocity_fits" not in self.params["metric_names"]:
self.params["metric_names"].append("velocity_fits")
if "velocity_below" in self.params["metric_params"]:
del self.params["metric_params"]["velocity_below"]

if "exp_decay" in self.params["metric_names"]:
if "exp_peak_function" in self.params["metric_params"]["exp_decay"]:
self.params["metric_params"]["exp_decay"]["peak_function"] = self.params["metric_params"]["exp_decay"][
"exp_peak_function"
]
if "min_r2_exp_decay" in self.params["metric_params"]["exp_decay"]:
self.params["metric_params"]["exp_decay"]["min_r2"] = self.params["metric_params"]["exp_decay"][
"min_r2_exp_decay"
]

if "depth_direction" not in self.params:
self.params["depth_direction"] = "y"

def _set_params(
self,
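The template-metrics migration both renames metrics (`velocity_above`/`velocity_below` fold into `velocity_fits`) and renames their parameter keys. A condensed sketch of the key mapping on a standalone dict; unlike the code above, it drops the old keys rather than keeping them alongside the new ones:

```python
params = {
    "metric_names": ["velocity_above", "exp_decay"],
    "metric_params": {
        "velocity_above": {"min_channels_for_velocity": 5, "min_r2_velocity": 0.7},
        "exp_decay": {"exp_peak_function": "ptp", "min_r2_exp_decay": 0.5},
    },
}

# velocity_above -> velocity_fits, with its sub-keys renamed
if "velocity_above" in params["metric_names"]:
    params["metric_names"].remove("velocity_above")
    if "velocity_fits" not in params["metric_names"]:
        params["metric_names"].append("velocity_fits")
    old = params["metric_params"].pop("velocity_above")
    params["metric_params"]["velocity_fits"] = {
        "min_channels": old["min_channels_for_velocity"],
        "min_r2": old["min_r2_velocity"],
    }

# exp_decay keeps its name but two parameter keys are renamed
exp = params["metric_params"]["exp_decay"]
for old_key, new_key in [("exp_peak_function", "peak_function"), ("min_r2_exp_decay", "min_r2")]:
    if old_key in exp:
        exp[new_key] = exp.pop(old_key)

assert params["metric_names"] == ["exp_decay", "velocity_fits"]
```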
@@ -74,9 +74,9 @@ def split_clusters(
peak_labels = peak_labels.copy()
split_count = np.zeros(peak_labels.size, dtype=int)
recursion_level = 1
Executor = get_poolexecutor(n_jobs)
executor = get_poolexecutor(n_jobs)

with Executor(
with executor(
max_workers=n_jobs,
initializer=split_worker_init,
mp_context=get_context(method=mp_context),
@@ -166,7 +166,6 @@ def split_worker_init(
_ctx = {}

_ctx["recording"] = recording
features_dict_or_folder
_ctx["original_labels"] = original_labels
_ctx["method"] = method
_ctx["method_kwargs"] = method_kwargs