diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index fea3f3618e..eda660a842 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -93,6 +93,22 @@ def _merge_extension_data( new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_mask]) return new_data + def _frame_slice_extension_data( + self, + start_frame, + end_frame, + ): + + new_data = dict() + random_spikes_indices = self.data["random_spikes_indices"] + spike_vector = self.sorting_analyzer.sorting.to_spike_vector() + first_spike_in_sliced_range = np.searchsorted(spike_vector["sample_index"], start_frame) + sample_indices = spike_vector[random_spikes_indices]["sample_index"] + indices_of_kept_spikes = np.where((sample_indices < end_frame) & (sample_indices > start_frame)) + new_data["random_spikes_indices"] = random_spikes_indices[indices_of_kept_spikes] - first_spike_in_sliced_range + + return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): new_data = dict() new_data["random_spikes_indices"] = self.data["random_spikes_indices"].copy() diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 70ad78353f..c596bc9914 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1578,6 +1578,42 @@ def get_dtype(self): def get_num_units(self) -> int: return self.sorting.get_num_units() + def frame_slice(self, start_frame=None, end_frame=None, slice_mode="hard", **job_kwargs): + """ + Do a frame slice. + """ + + assert slice_mode in ["soft", "hard"] + + sorting = self.sorting.frame_slice(start_frame=start_frame, end_frame=end_frame) + recording = self.recording.frame_slice(start_frame=start_frame, end_frame=end_frame) + sparsity = self.sparsity + + new_sorting_analyzer = SortingAnalyzer.create_memory( + sorting, recording, sparsity, self.return_in_uV, self.rec_attributes + ) + + sorted_extensions = _sort_extensions_by_dependency(self.extensions) + qm_extension_params = sorted_extensions.pop("quality_metrics", None) + if qm_extension_params is not None: + sorted_extensions["quality_metrics"] = qm_extension_params + + if slice_mode == "hard": + extensions_dict = { + extension_name: extension.params for extension_name, extension in sorted_extensions.items() + } + new_sorting_analyzer.compute(extensions_dict) + elif slice_mode == "soft": + for extension_name, extension in sorted_extensions.items(): + new_sorting_analyzer.extensions[extension_name] = extension.frame_slice( + new_sorting_analyzer, + start_frame=start_frame, + end_frame=end_frame, + **job_kwargs, + ) + + return new_sorting_analyzer + ## extensions zone def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs) -> "AnalyzerExtension | None": """ @@ -2232,6 +2268,10 @@ def _merge_extension_data( # must be implemented in subclass raise NotImplementedError + def _frame_slice_extension_data(self, start_frame, end_frame, **job_kwargs): + # must be implemented in subclass + raise NotImplementedError + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): # must be implemented in subclass raise NotImplementedError @@ -2472,6 +2512,20 @@ def merge( new_extension.save() return new_extension + def frame_slice( + self, + new_sorting_analyzer, + start_frame, + end_frame, + **job_kwargs, + ): + new_extension = self.__class__(new_sorting_analyzer) + new_extension.params = self.params.copy() + new_extension.data = self._frame_slice_extension_data(start_frame, end_frame, **job_kwargs) + new_extension.run_info = copy(self.run_info) + new_extension.save() + return new_extension + def split( self, new_sorting_analyzer,