diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index d686e7f175..77d877f4e7 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -17,7 +17,7 @@
 from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
 from .recording_tools import get_noise_levels
 from .template import Templates
-from .sorting_tools import random_spikes_selection
+from .sorting_tools import random_spikes_selection, select_sorting_periods_mask, spike_vector_to_indices
 from .job_tools import fix_job_kwargs, split_job_kwargs
@@ -1343,6 +1343,9 @@ class BaseSpikeVectorExtension(AnalyzerExtension):
     need_backward_compatibility_on_load = False
     nodepipeline_variables = []  # to be defined in subclass
 
+    def __init__(self, sorting_analyzer):
+        super().__init__(sorting_analyzer)
+
     def _set_params(self, **kwargs):
         params = kwargs.copy()
         return params
@@ -1381,7 +1384,7 @@ def _run(self, verbose=False, **job_kwargs):
         for d, name in zip(data, data_names):
             self.data[name] = d
 
-    def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, copy=True):
+    def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, periods=None, copy=True):
         """
         Return extension data. If the extension computes more than one `nodepipeline_variables`,
         the `return_data_name` is used to specify which one to return.
@@ -1395,13 +1398,15 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
         return_data_name : str | None, default: None
            The name of the data to return. If None and multiple `nodepipeline_variables` are computed,
            the first one is returned.
+        periods : numpy.array of unit_period_dtype, default: None
+            Optional periods (segment_index, start_sample_index, end_sample_index, unit_index) used to slice the output data.
         copy : bool, default: True
             Whether to return a copy of the data (only for outputs="numpy")
 
         Returns
         -------
         numpy.ndarray | dict
-            The
+            The requested data in numpy or by unit format.
         """
         if len(self.nodepipeline_variables) == 1:
@@ -1415,6 +1420,14 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
         ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"
         all_data = self.data[return_data_name]
 
+        keep_mask = None
+        if periods is not None:
+            keep_mask = select_sorting_periods_mask(
+                self.sorting_analyzer.sorting,
+                periods,
+            )
+            all_data = all_data[keep_mask]
+
         if outputs == "numpy":
             if copy:
                 return all_data.copy()  # return a copy to avoid modification
@@ -1422,8 +1435,14 @@
             return all_data
         elif outputs == "by_unit":
             unit_ids = self.sorting_analyzer.unit_ids
-            # use the cache of indices
-            spike_indices = self.sorting_analyzer.sorting.get_spike_vector_to_indices()
+            spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
+            if keep_mask is not None:
+                # since we are filtering spikes, we need to recompute the spike indices;
+                # keep_mask covers the concatenated spike vector, so split it per segment
+                segment_sizes = np.cumsum([0] + [len(sv) for sv in spike_vector])
+                spike_vector = [
+                    sv[keep_mask[segment_sizes[i] : segment_sizes[i + 1]]] for i, sv in enumerate(spike_vector)
+                ]
+                spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
+            else:
+                # use the cache of indices
+                spike_indices = self.sorting_analyzer.sorting.get_spike_vector_to_indices()
             data_by_units = {}
             for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
                 data_by_units[segment_index] = {}
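A minimal usage sketch (not part of the diff) for the new `periods` argument, assuming a `sorting_analyzer` on which a `BaseSpikeVectorExtension` subclass such as `spike_amplitudes` has already been computed; the period values are illustrative:

    import numpy as np
    from spikeinterface.core.base import unit_period_dtype

    # one period: unit index 2, segment 0, first 30_000 samples (illustrative values)
    periods = np.zeros(1, dtype=unit_period_dtype)
    periods["segment_index"] = 0
    periods["start_sample_index"] = 0
    periods["end_sample_index"] = 30_000
    periods["unit_index"] = 2

    ext = sorting_analyzer.get_extension("spike_amplitudes")  # assumed precomputed
    amplitudes = ext.get_data(periods=periods)  # flat numpy array, restricted to the period
    amplitudes_by_unit = ext.get_data(outputs="by_unit", periods=periods)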
""" if len(self.nodepipeline_variables) == 1: @@ -1415,6 +1420,14 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}" all_data = self.data[return_data_name] + keep_mask = None + if periods is not None: + keep_mask = select_sorting_periods_mask( + self.sorting_analyzer.sorting, + periods, + ) + all_data = all_data[keep_mask] + if outputs == "numpy": if copy: return all_data.copy() # return a copy to avoid modification @@ -1422,8 +1435,14 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, return all_data elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids - # use the cache of indices - spike_indices = self.sorting_analyzer.sorting.get_spike_vector_to_indices() + spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) + if keep_mask is not None: + # since we are filtering spikes, we need to recompute the spike indices + spike_vector = spike_vector[keep_mask] + spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) + else: + # use the cache of indices + spike_indices = self.sorting_analyzer.sorting.get_spike_vector_to_indices() data_by_units = {} for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): data_by_units[segment_index] = {} diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 3505853835..4520c19819 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -42,9 +42,9 @@ minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] base_period_dtype = [ + ("segment_index", "int64"), ("start_sample_index", "int64"), ("end_sample_index", "int64"), - ("segment_index", "int64"), ] unit_period_dtype = base_period_dtype + [ diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 5ab7fd0af3..42f98b3473 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -639,6 +639,26 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSo return self.frame_slice(start_frame=start_frame, end_frame=end_frame) + def select_periods(self, periods): + """ + Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype. + + Parameters + ---------- + periods : numpy.array of unit_period_dtype + Period (segment_index, start_sample_index, end_sample_index, unit_index) + on which to restrict the sorting. + + Returns + ------- + BaseSorting + A new sorting object with only samples between start_sample_index and end_sample_index + for the given segment_index. + """ + from spikeinterface.core.sorting_tools import select_sorting_periods + + return select_sorting_periods(self, periods) + def split_by(self, property="group", outputs="dict"): """ Splits object based on a certain property (e.g. "group") diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 90c7e18a99..bc0a1871af 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -228,6 +228,102 @@ def random_spikes_selection( return random_spikes_indices +def select_sorting_periods_mask(sorting: BaseSorting, periods): + """ + Returns a boolean mask for the spikes in the sorting object, restricted to the given periods of dtype unit_period_dtype. 
diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 90c7e18a99..bc0a1871af 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -228,6 +228,102 @@ def random_spikes_selection(
     return random_spikes_indices
 
 
+def select_sorting_periods_mask(sorting: BaseSorting, periods):
+    """
+    Returns a boolean mask for the spikes in the sorting object, restricted to the given periods
+    of dtype unit_period_dtype.
+
+    Parameters
+    ----------
+    sorting : BaseSorting
+        The sorting object.
+    periods : numpy.array of unit_period_dtype
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to restrict the sorting.
+
+    Returns
+    -------
+    numpy.array
+        A boolean mask of the spikes in the sorting object, with True for spikes within the specified periods.
+    """
+    spike_vector = sorting.to_spike_vector()
+    keep_mask = np.zeros(len(spike_vector), dtype=bool)
+    all_global_indices = sorting.get_spike_vector_to_indices()
+    for segment_index in range(sorting.get_num_segments()):
+        global_indices_segment = all_global_indices[segment_index]
+        # filter periods by segment
+        periods_in_segment = periods[periods["segment_index"] == segment_index]
+        for unit_index, unit_id in enumerate(sorting.unit_ids):
+            # filter by unit index
+            periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index]
+            if len(periods_for_unit) == 0:
+                continue
+            global_indices = global_indices_segment[unit_id]
+            # sample indices of this unit's spikes in this segment
+            spike_samples = spike_vector[global_indices]["sample_index"]
+            for period in periods_for_unit:
+                mask = (spike_samples >= period["start_sample_index"]) & (spike_samples < period["end_sample_index"])
+                keep_mask[global_indices[mask]] = True
+    return keep_mask
+
+
+def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting:
+    """
+    Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype.
+
+    Parameters
+    ----------
+    sorting : BaseSorting
+        The sorting object.
+    periods : numpy.ndarray
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to restrict the sorting. Periods can be either a numpy array of unit_period_dtype
+        or an array with (num_periods, 4) shape. In the latter case, the fields are assumed to be
+        in the order: segment_index, start_sample_index, end_sample_index, unit_index.
+
+    Returns
+    -------
+    BaseSorting
+        A new sorting object with only samples between start_sample_index and end_sample_index
+        for the given segment_index.
+    """
+    from spikeinterface.core.base import unit_period_dtype
+    from spikeinterface.core.numpyextractors import NumpySorting
+
+    if periods is None:
+        return sorting
+
+    if not isinstance(periods, np.ndarray):
+        raise ValueError("periods must be a numpy array")
+    if not periods.dtype == unit_period_dtype:
+        if periods.ndim != 2 or periods.shape[1] != 4:
+            raise ValueError(
+                "If periods is not of dtype unit_period_dtype, it must be a 2D array with shape (num_periods, 4)"
+            )
+        warnings.warn(
+            "periods is not of dtype unit_period_dtype. Assuming fields are in order: "
+            "(segment_index, start_sample_index, end_sample_index, unit_index).",
+            UserWarning,
+        )
+        # convert to structured array
+        periods_converted = np.empty(periods.shape[0], dtype=unit_period_dtype)
+        periods_converted["segment_index"] = periods[:, 0]
+        periods_converted["start_sample_index"] = periods[:, 1]
+        periods_converted["end_sample_index"] = periods[:, 2]
+        periods_converted["unit_index"] = periods[:, 3]
+        periods = periods_converted
+
+    required = set(np.dtype(unit_period_dtype).names)
+    if not required.issubset(periods.dtype.names):
+        raise ValueError(f"Periods must have the following fields: {required}")
+
+    spike_vector = sorting.to_spike_vector()
+    keep_mask = select_sorting_periods_mask(sorting, periods)
+    sliced_spike_vector = spike_vector[keep_mask]
+
+    # important: we keep the original unit ids so the unit_index field in the spike vector is still valid
+    new_sorting = NumpySorting(
+        sliced_spike_vector, sampling_frequency=sorting.sampling_frequency, unit_ids=sorting.unit_ids
+    )
+    # copy annotations/properties from the original sorting onto the restricted copy
+    sorting.copy_metadata(new_sorting)
+    return new_sorting
+
+
 ### MERGING ZONE ###
 def apply_merges_to_sorting(
     sorting: BaseSorting,
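Both helpers can also be used directly; a sketch reusing the `periods` array from the previous example, and showing the plain `(num_periods, 4)` fallback, which emits a `UserWarning` before converting:

    from spikeinterface.core.sorting_tools import select_sorting_periods, select_sorting_periods_mask

    # boolean mask over the concatenated spike vector (also indexes any per-spike array aligned with it)
    keep_mask = select_sorting_periods_mask(sorting, periods)
    spikes_in_periods = sorting.to_spike_vector()[keep_mask]

    # equivalent call with the plain-array form: columns are
    # (segment_index, start_sample_index, end_sample_index, unit_index)
    periods_array = np.array([[0, 0, 5_000, 0], [0, 10_000, 20_000, 2]], dtype="int64")
    sliced = select_sorting_periods(sorting, periods_array)  # warns, then converts and slices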
diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py
index 4f10be4c26..6c06b212b8 100644
--- a/src/spikeinterface/core/tests/test_basesorting.py
+++ b/src/spikeinterface/core/tests/test_basesorting.py
@@ -3,9 +3,7 @@
 but check only for BaseRecording general methods.
 """
 
-import shutil
-from pathlib import Path
-
+import time
 import numpy as np
 import pytest
 from numpy.testing import assert_raises
@@ -17,15 +15,14 @@
     SharedMemorySorting,
     NpzFolderSorting,
     NumpyFolderSorting,
+    generate_ground_truth_recording,
     create_sorting_npz,
     generate_sorting,
     load,
 )
-from spikeinterface.core.base import BaseExtractor
+from spikeinterface.core.base import BaseExtractor, unit_period_dtype
 from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal
-from spikeinterface.core.generate import generate_sorting
-
-from spikeinterface.core import generate_recording, generate_ground_truth_recording
 
 
 def test_BaseSorting(create_cache_folder):
@@ -245,6 +242,74 @@ def test_time_slice():
     )
 
 
+def test_select_periods():
+    sampling_frequency = 10_000.0
+    duration = 100
+    num_samples = int(sampling_frequency * duration)
+    num_units = 1000
+    sorting = generate_sorting(
+        durations=[duration, duration], sampling_frequency=sampling_frequency, num_units=num_units
+    )
+
+    rng = np.random.default_rng()
+
+    # number of random periods
+    n_periods = 1_000
+    # generate random periods
+    segment_indices = rng.integers(0, sorting.get_num_segments(), n_periods)
+    start_samples = rng.integers(0, num_samples, n_periods)
+    durations = rng.integers(100, 100_000, n_periods)
+    end_samples = start_samples + durations
+    valid_periods = end_samples < num_samples
+    segment_indices = segment_indices[valid_periods]
+    start_samples = start_samples[valid_periods]
+    end_samples = end_samples[valid_periods]
+    unit_index = rng.integers(0, num_units, len(segment_indices))
+
+    periods = np.zeros(len(segment_indices), dtype=unit_period_dtype)
+    periods["segment_index"] = segment_indices
+    periods["start_sample_index"] = start_samples
+    periods["end_sample_index"] = end_samples
+    periods["unit_index"] = unit_index
+    periods = np.sort(periods, order=["segment_index", "start_sample_index"])
+
+    t_start = time.perf_counter()
+    sliced_sorting = sorting.select_periods(periods=periods)
+    t_stop = time.perf_counter()
+    elapsed = t_stop - t_start
+    print(f"select_periods took {elapsed:.2f} seconds for {len(periods)} periods")
+
+    # Check that all spikes in the sliced sorting are within the periods
+    for segment_index in range(sorting.get_num_segments()):
+        periods_in_segment = periods[periods["segment_index"] == segment_index]
+        for unit_index, unit_id in enumerate(sorting.unit_ids):
+            spiketrain = sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id)
+
+            periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index]
+            spiketrain_in_periods = []
+            for period in periods_for_unit:
+                start_sample = period["start_sample_index"]
+                end_sample = period["end_sample_index"]
+                spiketrain_in_periods.append(spiketrain[(spiketrain >= start_sample) & (spiketrain < end_sample)])
+            if len(spiketrain_in_periods) == 0:
+                spiketrain_in_periods = np.array([], dtype=spiketrain.dtype)
+            else:
+                spiketrain_in_periods = np.unique(np.concatenate(spiketrain_in_periods))
+
+            spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id)
+            assert len(spiketrain_in_periods) == len(spiketrain_sliced)
+
+    # now test with input as numpy array with shape (n_periods, 4)
+    periods_array = np.zeros((len(periods), 4), dtype="int64")
+    periods_array[:, 0] = periods["segment_index"]
+    periods_array[:, 1] = periods["start_sample_index"]
+    periods_array[:, 2] = periods["end_sample_index"]
+    periods_array[:, 3] = periods["unit_index"]
+
+    sliced_sorting_array = sorting.select_periods(periods=periods_array)
+    np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector())
+
+
 if __name__ == "__main__":
     import tempfile
@@ -254,3 +319,4 @@
     test_BaseSorting(cache_folder)
     test_npy_sorting()
     test_empty_sorting()
+    test_select_periods()
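To exercise the new test locally, something like the following should work; the `-s` flag disables pytest's output capture so the timing print is visible:

    pytest -s src/spikeinterface/core/tests/test_basesorting.py::test_select_periods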
diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
index ab7ae9e7b5..75e41620f4 100644
--- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
+++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
@@ -130,7 +130,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan
         recording._properties["contact_vector"][idx][1] = x[idx]
 
     # generate random bad channel locations
-    bad_channel_indexes = rng.choice(num_channels, rng.randint(1, int(num_channels / 5)), replace=False)
+    bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False)
     bad_channel_ids = recording.channel_ids[bad_channel_indexes]
 
     # Run SI and IBL interpolation and check against eachother
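For context on this last hunk: `np.random.default_rng()` returns a `numpy.random.Generator`, which exposes `integers` but not the legacy `RandomState.randint`, so the old call would raise:

    import numpy as np

    rng = np.random.default_rng(42)
    rng.integers(1, 10)   # Generator API; high is exclusive, like the legacy randint
    # rng.randint(1, 10)  # AttributeError: 'Generator' object has no attribute 'randint'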