diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index d686e7f175..7833a2aaa2 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -604,8 +604,9 @@ def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer,
                 arr = np.std(wfs, axis=0)
             elif operator == "median":
                 arr = np.median(wfs, axis=0)
-            elif "percentile" in operator:
-                _, percentile = operator.splot("_")
+            # there was a spelling error ("pencentile") in old versions
+            elif "percentile" in operator or "pencentile" in operator:
+                _, percentile = operator.split("_")
                 arr = np.percentile(wfs, float(percentile), axis=0)
             new_array[split_unit_index, ...] = arr
         else:
@@ -824,7 +825,7 @@ class BaseMetric:
     metric_descriptions = {}  # descriptions of each metric column
     needs_recording = False  # whether the metric needs recording
     needs_tmp_data = (
-        False  # whether the metric needs temporary data comoputed with _prepare_data at the MetricExtension level
+        False  # whether the metric needs temporary data computed with _prepare_data at the MetricExtension level
     )
     needs_job_kwargs = False
     depend_on = []  # extensions the metric depends on
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 1870c24e7a..e8ad30329a 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -2238,6 +2238,7 @@ class AnalyzerExtension:
       * _run()
       * _select_extension_data()
       * _merge_extension_data()
+      * _split_extension_data()
       * _get_data()
 
     The subclass must also set an `extension_name` class attribute which is not None by default.
diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py
index 840ef07353..cda01f5313 100644
--- a/src/spikeinterface/curation/curation_format.py
+++ b/src/spikeinterface/curation/curation_format.py
@@ -6,7 +6,7 @@
 from itertools import chain
 
 from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting, apply_splits_to_sorting
-from spikeinterface.curation.curation_model import CurationModel
+from spikeinterface.curation.curation_model import CurationModel, SequentialCuration
 
 
 def validate_curation_dict(curation_dict: dict):
@@ -138,7 +138,7 @@ def apply_curation_labels(
 
 def apply_curation(
     sorting_or_analyzer: BaseSorting | SortingAnalyzer,
-    curation_dict_or_model: dict | CurationModel,
+    curation_dict_or_model: dict | list | CurationModel | SequentialCuration,
     censor_ms: float | None = None,
     new_id_strategy: str = "append",
     merging_mode: str = "soft",
@@ -164,7 +164,9 @@ def apply_curation(
     ----------
     sorting_or_analyzer : Sorting | SortingAnalyzer
         The Sorting or SortingAnalyzer object to apply merges.
-    curation_dict : dict or CurationModel
+    curation_dict_or_model : dict | list | CurationModel | SequentialCuration
         The curation dict or model.
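+        If a plain list of curation dicts or models is given, it is interpreted as the
+        `curation_steps` of a `SequentialCuration`, and the steps are applied in order.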
     censor_ms : float | None, default: None
         When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of
@@ -199,14 +201,32 @@ def apply_curation(
     assert isinstance(
         sorting_or_analyzer, (BaseSorting, SortingAnalyzer)
     ), f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}"
     assert isinstance(
-        curation_dict_or_model, (dict, CurationModel)
-    ), f"`curation_dict_or_model` must be a dict or a CurationModel, not an object of type {type(curation_dict_or_model)}"
+        curation_dict_or_model, (dict, list, CurationModel, SequentialCuration)
+    ), f"`curation_dict_or_model` must be a dict, a list, a CurationModel or a SequentialCuration, not an object of type {type(curation_dict_or_model)}"
 
     if isinstance(curation_dict_or_model, dict):
         curation_model = CurationModel(**curation_dict_or_model)
+    elif isinstance(curation_dict_or_model, list):
+        curation_model = SequentialCuration(curation_steps=curation_dict_or_model)
     else:
         curation_model = curation_dict_or_model.model_copy(deep=True)
 
-    if not np.array_equal(np.asarray(curation_model.unit_ids), sorting_or_analyzer.unit_ids):
+    if isinstance(curation_model, SequentialCuration):
+        for c, single_curation_model in enumerate(curation_model.curation_steps):
+            if verbose:
+                print(f"Applying curation step: {c + 1} / {len(curation_model.curation_steps)}")
+            sorting_or_analyzer = apply_curation(
+                sorting_or_analyzer,
+                single_curation_model,
+                censor_ms=censor_ms,
+                merging_mode=merging_mode,
+                sparsity_overlap=sparsity_overlap,
+                raise_error_if_overlap_fails=raise_error_if_overlap_fails,
+                verbose=verbose,
+                job_kwargs=job_kwargs,
+            )
+        return sorting_or_analyzer
+
+    if set(curation_model.unit_ids) != set(sorting_or_analyzer.unit_ids):
         raise ValueError("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer")
 
     # 1. Apply labels
@@ -228,6 +248,7 @@
         curated_sorting_or_analyzer, _, _ = apply_merges_to_sorting(
             curated_sorting_or_analyzer,
             merge_unit_groups=merge_unit_groups,
+            new_unit_ids=merge_new_unit_ids,
             censor_ms=censor_ms,
             new_id_strategy=new_id_strategy,
             return_extra=True,
@@ -235,6 +256,7 @@
     else:
         curated_sorting_or_analyzer, _ = curated_sorting_or_analyzer.merge_units(
             merge_unit_groups=merge_unit_groups,
+            new_unit_ids=merge_new_unit_ids,
             censor_ms=censor_ms,
             merging_mode=merging_mode,
             sparsity_overlap=sparsity_overlap,
diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py
index ac89fde04a..3e9a6d9285 100644
--- a/src/spikeinterface/curation/curation_model.py
+++ b/src/spikeinterface/curation/curation_model.py
@@ -190,7 +190,7 @@ def check_merges(cls, values):
 
             # Check new unit id not already used
             if merge.new_unit_id is not None:
-                if merge.new_unit_id in unit_ids:
+                if merge.new_unit_id in unit_ids and merge.new_unit_id not in merge.unit_ids:
                     raise ValueError(f"New unit ID {merge.new_unit_id} is already in the unit list")
 
         values["merges"] = merges
@@ -366,6 +366,54 @@ def convert_old_format(cls, values):
             values["removed"] = list(removed_units)
         return values
 
+    def get_final_ids_from_new_unit_ids(self) -> list:
+        """
+        Returns the final unit ids of the curation model, assuming new unit ids are
+        given for each curation choice. Raises an error if new unit ids are missing
+        for any curation choice.
+
+        Returns
+        -------
+        final_ids : list
+            The ids of the sorting/analyzer after curation takes place
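+
+        Examples
+        --------
+        A minimal sketch (hypothetical unit ids): merging units 1 and 2 into a new
+        unit 12 leaves final units {3, 12}:
+
+        >>> model = CurationModel(
+        ...     format_version="2",
+        ...     unit_ids=[1, 2, 3],
+        ...     merges=[{"unit_ids": [1, 2], "new_unit_id": 12}],
+        ... )
+        >>> sorted(model.get_final_ids_from_new_unit_ids())
+        [3, 12]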
+        """
+
+        initial_ids = self.unit_ids
+
+        ids_to_add = []
+        ids_to_remove = list(self.removed)
+
+        for split in self.splits:
+            if split.new_unit_ids is None:
+                raise ValueError(
+                    f"The `new_unit_ids` for the split of unit {split.unit_id} is `None`. These must be given."
+                )
+            ids_to_remove.append(split.unit_id)
+            ids_to_add = ids_to_add + split.new_unit_ids
+        for merge in self.merges:
+            if merge.new_unit_id is None:
+                raise ValueError(
+                    f"The `new_unit_id` for the merge of units {merge.unit_ids} is `None`. This must be given."
+                )
+            ids_to_remove = ids_to_remove + merge.unit_ids
+            ids_to_add.append(merge.new_unit_id)
+
+        final_ids = set(initial_ids).difference(ids_to_remove).union(ids_to_add)
+        return list(final_ids)
+
     @model_validator(mode="before")
     def validate_fields(cls, values):
         values = dict(values)
@@ -430,3 +478,54 @@ def validate_curation_dict(self):
             )
 
         return self
+
+
+class SequentialCuration(BaseModel):
+    """
+    A Pydantic model which defines a sequence of curation steps. In a sequential
+    curation, each individual step (except the final one) must have manually defined
+    new unit ids, and these must match the unit ids of the following step.
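+
+    Examples
+    --------
+    A minimal sketch (hypothetical unit ids): merge units 1 and 2 into 12, then remove unit 3:
+
+    >>> sequence = SequentialCuration(
+    ...     curation_steps=[
+    ...         {"format_version": "2", "unit_ids": [1, 2, 3], "merges": [{"unit_ids": [1, 2], "new_unit_id": 12}]},
+    ...         {"format_version": "2", "unit_ids": [3, 12], "removed": [3]},
+    ...     ]
+    ... )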
+    """
+
+    curation_steps: List[CurationModel] = Field(description="List of curation steps applied sequentially")
+
+    @model_validator(mode="after")
+    def validate_sequential_curation(self):
+
+        for curation in self.curation_steps[:-1]:
+            for merge in curation.merges:
+                if merge.new_unit_id is None:
+                    raise ValueError(
+                        "In a sequential curation, all curation decisions must have explicit `new_unit_id`s defined."
+                    )
+            for split in curation.splits:
+                if split.new_unit_ids is None:
+                    raise ValueError(
+                        "In a sequential curation, all curation decisions must have explicit `new_unit_id`s defined."
+                    )
+
+        for curation_index in range(len(self.curation_steps) - 1):
+
+            curation_1 = self.curation_steps[curation_index]
+            curation_2 = self.curation_steps[curation_index + 1]
+
+            previous_model_final_ids = curation_1.get_final_ids_from_new_unit_ids()
+            next_model_initial_ids = curation_2.unit_ids
+
+            if set(previous_model_final_ids) != set(next_model_initial_ids):
+                raise ValueError(
+                    f"The initial unit_ids of curation {curation_index + 1} do not match the final unit_ids of curation {curation_index}."
+                )
+
+        return self
diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py
index 85df3c57bf..40e0a9b2b7 100644
--- a/src/spikeinterface/curation/tests/test_curation_format.py
+++ b/src/spikeinterface/curation/tests/test_curation_format.py
@@ -171,6 +171,27 @@
 # This is a failure because unit 99 is not in the initial list
 unknown_removed_unit = {**curation_ids_int, "removed": [31, 42, 99]}
 
+# Sequential curation test data
+sequential_curation = [
+    {
+        "format_version": "2",
+        "unit_ids": [1, 2, 3, 4, 5],
+        "merges": [{"unit_ids": [3, 4], "new_unit_id": 34}],
+    },
+    {
+        "format_version": "2",
+        "unit_ids": [1, 2, 34, 5],
+        "splits": [{"unit_id": 34, "mode": "indices", "indices": [[0, 1, 2, 3]], "new_unit_ids": [340, 341]}],
+    },
+    {
+        "format_version": "2",
+        "unit_ids": [1, 2, 340, 341, 5],
+        "removed": [2, 5],
+        "merges": [{"unit_ids": [1, 340], "new_unit_id": 100}],
+        "splits": [{"unit_id": 341, "mode": "indices", "indices": [[0, 1, 2]], "new_unit_ids": [3410, 3411]}],
+    },
+]
+
 
 def test_curation_format_validation():
     # Test basic formats
@@ -412,6 +433,26 @@ def test_apply_curation_splits_with_mask():
     assert spike_counts[45] == num_spikes - 2 * (num_spikes // 3)  # Remainder
 
 
+def test_apply_sequential_curation():
+    recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=2205)
+    sorting = sorting.rename_units([1, 2, 3, 4, 5])
+    analyzer = create_sorting_analyzer(sorting, recording, sparse=False)
+
+    # sequential curation steps:
+    # 1. merge 3 and 4 -> 34
+    # 2. split 34 -> 340, 341
+    # 3. remove 2, 5; merge 1 and 340 -> 100; split 341 -> 3410, 3411
+    analyzer_curated = apply_curation(analyzer, sequential_curation, verbose=True)
+    # net change in unit count: -1 (merge) +1 (split) -2 (remove) -1 (merge) +1 (split)
+    num_final_units = analyzer.get_num_units() - 1 + 1 - 2 - 1 + 1
+    assert analyzer_curated.get_num_units() == num_final_units
+
+    # check final unit ids
+    final_unit_ids = analyzer_curated.sorting.unit_ids
+    expected_final_unit_ids = [100, 3410, 3411]
+    assert set(final_unit_ids) == set(expected_final_unit_ids)
+
+
 if __name__ == "__main__":
     test_curation_format_validation()
     test_to_from_json()
diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py
index 951f33a300..147dfddf8a 100644
--- a/src/spikeinterface/curation/tests/test_curation_model.py
+++ b/src/spikeinterface/curation/tests/test_curation_model.py
@@ -3,7 +3,8 @@
 from pydantic import ValidationError
 import numpy as np
 
-from spikeinterface.curation.curation_model import CurationModel, LabelDefinition
+from spikeinterface.curation.curation_model import CurationModel, SequentialCuration, LabelDefinition
+from copy import deepcopy
 
 
 # Test data for format version
@@ -282,3 +283,34 @@ def test_complete_model():
     assert len(model.merges) == 1
     assert len(model.splits) == 1
     assert len(model.removed) == 1
+
+
+def test_sequential_curation():
+    sequential_curation_steps_valid = [
+        {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": [{"unit_ids": [1, 2], "new_unit_id": 22}]},
+        {
+            "format_version": "2",
+            "unit_ids": [3, 4, 22],
+            "splits": [
+                {"unit_id": 22, "mode": "indices", "indices": [[0, 1, 2], [3, 4, 5]], "new_unit_ids": [222, 223]}
+            ],
+        },
+        {"format_version": "2", "unit_ids": [3, 4, 222, 223], "removed": [223]},
+    ]
+
+    # this is valid
+    SequentialCuration(curation_steps=sequential_curation_steps_valid)
+
+    # deepcopy so that editing the nested dicts does not mutate the valid test data
+    sequential_curation_steps_no_ids = deepcopy(sequential_curation_steps_valid)
+    # remove new_unit_id in merge step
+    sequential_curation_steps_no_ids[0]["merges"][0]["new_unit_id"] = None
+
+    with pytest.raises(ValidationError):
+        SequentialCuration(curation_steps=sequential_curation_steps_no_ids)
+
+    sequential_curation_steps_invalid = deepcopy(sequential_curation_steps_valid)
+    # invalid unit_ids in last step
+    sequential_curation_steps_invalid[2]["unit_ids"] = [3, 4, 222, 224]  # 224 should be 223
+    with pytest.raises(ValidationError):
+        SequentialCuration(curation_steps=sequential_curation_steps_invalid)
diff --git a/src/spikeinterface/metrics/quality/quality_metrics.py b/src/spikeinterface/metrics/quality/quality_metrics.py
index 8e96f4dcaf..67012ff33f 100644
--- a/src/spikeinterface/metrics/quality/quality_metrics.py
+++ b/src/spikeinterface/metrics/quality/quality_metrics.py
@@ -64,6 +64,15 @@ def _handle_backward_compatibility_on_load(self):
             if "mahalanobis" not in self.params["metric_names"]:
                 self.params["metric_names"].append("mahalanobis")
 
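+        # drop the "peak_sign" entry saved by older versions; presumably it is no longer a valid parameter for these metrics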
+        if "amplitude_cutoff" in self.params["metric_names"]:
+            if "peak_sign" in self.params["metric_params"]["amplitude_cutoff"]:
+                del self.params["metric_params"]["amplitude_cutoff"]["peak_sign"]
+
+        if "amplitude_median" in self.params["metric_names"]:
+            if "peak_sign" in self.params["metric_params"]["amplitude_median"]:
+                del self.params["metric_params"]["amplitude_median"]["peak_sign"]
+
     def _set_params(
         self,
         metric_names: list[str] | None = None,
diff --git a/src/spikeinterface/metrics/template/template_metrics.py b/src/spikeinterface/metrics/template/template_metrics.py
index 83a9048a64..9f66266417 100644
--- a/src/spikeinterface/metrics/template/template_metrics.py
+++ b/src/spikeinterface/metrics/template/template_metrics.py
@@ -112,19 +112,49 @@ def _handle_backward_compatibility_on_load(self):
         if "num_positive_peaks" in self.params["metric_names"]:
             self.params["metric_names"].remove("num_positive_peaks")
             if "number_of_peaks" not in self.params["metric_names"]:
                 self.params["metric_names"].append("number_of_peaks")
+                self.params["metric_params"]["number_of_peaks"] = self.params["metric_params"]["num_positive_peaks"]
         if "num_negative_peaks" in self.params["metric_names"]:
             self.params["metric_names"].remove("num_negative_peaks")
             if "number_of_peaks" not in self.params["metric_names"]:
                 self.params["metric_names"].append("number_of_peaks")
+
+        # velocity_above/velocity_below merged into velocity_fits
         if "velocity_above" in self.params["metric_names"]:
             self.params["metric_names"].remove("velocity_above")
             if "velocity_fits" not in self.params["metric_names"]:
                 self.params["metric_names"].append("velocity_fits")
+
+            # also update parameters
+            if "velocity_above" in self.params["metric_params"]:
+                self.params["metric_params"]["velocity_fits"] = self.params["metric_params"]["velocity_above"]
+                self.params["metric_params"]["velocity_fits"]["min_channels"] = self.params["metric_params"][
+                    "velocity_above"
+                ]["min_channels_for_velocity"]
+                self.params["metric_params"]["velocity_fits"]["min_r2"] = self.params["metric_params"][
+                    "velocity_above"
+                ]["min_r2_velocity"]
+                del self.params["metric_params"]["velocity_above"]
+
         if "velocity_below" in self.params["metric_names"]:
             self.params["metric_names"].remove("velocity_below")
             if "velocity_fits" not in self.params["metric_names"]:
                 self.params["metric_names"].append("velocity_fits")
+            if "velocity_below" in self.params["metric_params"]:
+                del self.params["metric_params"]["velocity_below"]
+
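+        # map the old exp_decay parameter names onto the new ones: "exp_peak_function" -> "peak_function", "min_r2_exp_decay" -> "min_r2"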
self.params["metric_params"]["exp_decay"]["peak_function"] = self.params["metric_params"]["exp_decay"][ + "exp_peak_function" + ] + if "min_r2_exp_decay" in self.params["metric_params"]["exp_decay"]: + self.params["metric_params"]["exp_decay"]["min_r2"] = self.params["metric_params"]["exp_decay"][ + "min_r2_exp_decay" + ] + + if "depth_direction" not in self.params: + self.params["depth_direction"] = "y" def _set_params( self, diff --git a/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py b/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py index 0411166de6..e1a03f23f8 100644 --- a/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/itersplit_tools.py @@ -74,9 +74,9 @@ def split_clusters( peak_labels = peak_labels.copy() split_count = np.zeros(peak_labels.size, dtype=int) recursion_level = 1 - Executor = get_poolexecutor(n_jobs) + executor = get_poolexecutor(n_jobs) - with Executor( + with executor( max_workers=n_jobs, initializer=split_worker_init, mp_context=get_context(method=mp_context), @@ -166,7 +166,6 @@ def split_worker_init( _ctx = {} _ctx["recording"] = recording - features_dict_or_folder _ctx["original_labels"] = original_labels _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs