16 changes: 16 additions & 0 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -93,6 +93,22 @@ def _merge_extension_data(
new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_mask])
return new_data

def _frame_slice_extension_data(
self,
start_frame,
end_frame,
):

new_data = dict()
random_spikes_indices = self.data["random_spikes_indices"]
spike_vector = self.sorting_analyzer.sorting.to_spike_vector()
# number of spikes that fall before start_frame, i.e. the index offset between the
# full spike vector and the frame-sliced one
first_spike_in_sliced_range = np.searchsorted(spike_vector["sample_index"], start_frame)
# keep only the random spikes whose sample index falls in [start_frame, end_frame),
# matching the spikes kept by the sliced sorting
sample_indices = spike_vector[random_spikes_indices]["sample_index"]
indices_of_kept_spikes = np.flatnonzero((sample_indices >= start_frame) & (sample_indices < end_frame))
# shift the kept indices so they point into the sliced spike vector
new_data["random_spikes_indices"] = random_spikes_indices[indices_of_kept_spikes] - first_spike_in_sliced_range
return new_data

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
new_data = dict()
new_data["random_spikes_indices"] = self.data["random_spikes_indices"].copy()
54 changes: 54 additions & 0 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -1578,6 +1578,42 @@ def get_dtype(self):
def get_num_units(self) -> int:
return self.sorting.get_num_units()

def frame_slice(self, start_frame=None, end_frame=None, slice_mode="hard", **job_kwargs):
"""
Do a frame slice.
"""

assert slice_mode in ["soft", "hard"]

sorting = self.sorting.frame_slice(start_frame=start_frame, end_frame=end_frame)
recording = self.recording.frame_slice(start_frame=start_frame, end_frame=end_frame)
sparsity = self.sparsity

new_sorting_analyzer = SortingAnalyzer.create_memory(
sorting, recording, sparsity, self.return_in_uV, self.rec_attributes
)

sorted_extensions = _sort_extensions_by_dependency(self.extensions)
# quality_metrics can use many other extensions, so move it to the end of the
# (insertion-ordered) dict to make sure it is handled after its dependencies
qm_extension_params = sorted_extensions.pop("quality_metrics", None)
if qm_extension_params is not None:
sorted_extensions["quality_metrics"] = qm_extension_params

if slice_mode == "hard":
extensions_dict = {
extension_name: extension.params for extension_name, extension in sorted_extensions.items()
}
new_sorting_analyzer.compute(extensions_dict)
elif slice_mode == "soft":
for extension_name, extension in sorted_extensions.items():
new_sorting_analyzer.extensions[extension_name] = extension.frame_slice(
new_sorting_analyzer,
start_frame=start_frame,
end_frame=end_frame,
**job_kwargs,
)

return new_sorting_analyzer
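
A usage sketch of the new `frame_slice` method (illustrative only, not part of the diff; it builds a toy analyzer with the existing `generate_ground_truth_recording` and `create_sorting_analyzer` helpers):

from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording

# build a small in-memory analyzer with one computed extension
recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5)
analyzer = create_sorting_analyzer(sorting, recording, format="memory")
analyzer.compute("random_spikes")

# "hard" recomputes the extensions on the sliced data, "soft" slices the existing data
sliced_hard = analyzer.frame_slice(start_frame=0, end_frame=100_000, slice_mode="hard")
sliced_soft = analyzer.frame_slice(start_frame=0, end_frame=100_000, slice_mode="soft")
print(sliced_soft.get_extension("random_spikes").data["random_spikes_indices"])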

## extensions zone
def compute(self, input, save=True, extension_params=None, verbose=False, **kwargs) -> "AnalyzerExtension | None":
"""
@@ -2232,6 +2268,10 @@ def _merge_extension_data(
# must be implemented in subclass
raise NotImplementedError

def _frame_slice_extension_data(self, start_frame, end_frame, **job_kwargs):
# must be implemented in subclass
raise NotImplementedError

def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
# must be implemented in subclass
raise NotImplementedError
@@ -2472,6 +2512,20 @@ def merge(
new_extension.save()
return new_extension

def frame_slice(
self,
new_sorting_analyzer,
start_frame,
end_frame,
**job_kwargs,
):
new_extension = self.__class__(new_sorting_analyzer)
new_extension.params = self.params.copy()
new_extension.data = self._frame_slice_extension_data(start_frame, end_frame, **job_kwargs)
new_extension.run_info = copy(self.run_info)
new_extension.save()
return new_extension

def split(
self,
new_sorting_analyzer,
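For extension authors, a hypothetical sketch of what an override of the new `_frame_slice_extension_data` hook could look like (the class name and the "values" array are invented for illustration; it only relies on the `self.data` and `self.sorting_analyzer` attributes every AnalyzerExtension has, and in this diff only the random-spikes extension actually implements the hook):

class MyPerSpikeExtension:
    """Hypothetical sketch of just the slicing hook, not a complete AnalyzerExtension."""

    def _frame_slice_extension_data(self, start_frame, end_frame, **job_kwargs):
        # keep only the entries whose spike falls inside [start_frame, end_frame)
        spike_vector = self.sorting_analyzer.sorting.to_spike_vector()
        sample_indices = spike_vector["sample_index"]
        keep_mask = (sample_indices >= start_frame) & (sample_indices < end_frame)
        new_data = dict()
        new_data["values"] = self.data["values"][keep_mask]
        return new_data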