diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index 74ef52e258..804418a2ff 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -17,7 +17,7 @@
 from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
 from .recording_tools import get_noise_levels
 from .template import Templates
-from .sorting_tools import random_spikes_selection
+from .sorting_tools import random_spikes_selection, select_sorting_periods_mask
 from .job_tools import fix_job_kwargs, split_job_kwargs


@@ -1331,6 +1331,21 @@ class BaseSpikeVectorExtension(AnalyzerExtension):
     need_backward_compatibility_on_load = False
     nodepipeline_variables = []  # to be defined in subclass

+    def __init__(self, sorting_analyzer):
+        super().__init__(sorting_analyzer)
+        self._segment_slices = None
+
+    @property
+    def segment_slices(self):
+        if self._segment_slices is None:
+            segment_slices = []
+            spikes = self.sorting_analyzer.sorting.to_spike_vector()
+            for segment_index in range(self.sorting_analyzer.get_num_segments()):
+                i0, i1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1])
+                segment_slices.append(slice(i0, i1))
+            self._segment_slices = segment_slices
+        return self._segment_slices
+
     def _set_params(self, **kwargs):
         params = kwargs.copy()
         return params
@@ -1369,7 +1384,7 @@ def _run(self, verbose=False, **job_kwargs):
         for d, name in zip(data, data_names):
             self.data[name] = d

-    def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, copy=True):
+    def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, periods=None, copy=True):
         """
         Return extension data. If the extension computes more than one `nodepipeline_variables`,
         the `return_data_name` is used to specify which one to return.
@@ -1383,13 +1398,15 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
         return_data_name : str | None, default: None
             The name of the data to return. If None and multiple `nodepipeline_variables` are computed,
             the first one is returned.
+        periods : array of unit_period_dtype | None, default: None
+            Optional periods (segment_index, start_sample_index, end_sample_index, unit_index) used to slice the output data.
         copy : bool, default: True
             Whether to return a copy of the data (only for outputs="numpy")

         Returns
         -------
         numpy.ndarray | dict
-            The
+            The requested data, as a numpy array or as a dict by unit.
         """
         from spikeinterface.core.sorting_tools import spike_vector_to_indices
@@ -1404,6 +1421,14 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
             ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"
             all_data = self.data[return_data_name]

+        keep_mask = None
+        if periods is not None:
+            keep_mask = select_sorting_periods_mask(
+                self.sorting_analyzer.sorting,
+                periods,
+            )
+            all_data = all_data[keep_mask]
+
         if outputs == "numpy":
             if copy:
                 return all_data.copy()  # return a copy to avoid modification
@@ -1412,6 +1437,9 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
         elif outputs == "by_unit":
             unit_ids = self.sorting_analyzer.unit_ids
             spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
+            if keep_mask is not None:
+                # to_spike_vector(concatenated=False) returns one array per segment, so apply the mask per segment
+                spike_vector = [sv[keep_mask[sl]] for sv, sl in zip(spike_vector, self.segment_slices)]
             spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
             data_by_units = {}
             for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
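
Reviewer note: the cached `segment_slices` property and the masked "by_unit" path above can be pictured with a small standalone sketch (plain numpy, toy values only):

```python
import numpy as np

# A concatenated spike vector sorted by segment can be cut into per-segment
# slices with searchsorted; this is what the cached `segment_slices` does.
spikes = np.zeros(6, dtype=[("sample_index", "int64"), ("segment_index", "int64")])
spikes["segment_index"] = [0, 0, 0, 1, 1, 1]
spikes["sample_index"] = [10, 50, 90, 5, 40, 80]

num_segments = 2
segment_slices = []
for segment_index in range(num_segments):
    i0, i1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1])
    segment_slices.append(slice(i0, i1))

# A boolean keep_mask over the concatenated vector is then applied segment by
# segment, mirroring the masked "by_unit" branch of `_get_data`.
keep_mask = spikes["sample_index"] >= 40
per_segment = [spikes[sl][keep_mask[sl]] for sl in segment_slices]
print([len(sv) for sv in per_segment])  # [2, 2]
```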
""" from spikeinterface.core.sorting_tools import spike_vector_to_indices @@ -1404,6 +1421,14 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}" all_data = self.data[return_data_name] + keep_mask = None + if periods is not None: + keep_mask = select_sorting_periods_mask( + self.sorting_analyzer.sorting, + periods, + ) + all_data = all_data[keep_mask] + if outputs == "numpy": if copy: return all_data.copy() # return a copy to avoid modification @@ -1412,6 +1437,8 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, elif outputs == "by_unit": unit_ids = self.sorting_analyzer.unit_ids spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False) + if keep_mask is not None: + spike_vector = spike_vector[keep_mask] spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True) data_by_units = {} for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()): diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 98159fb646..b6440f8e2b 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -626,6 +626,26 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSo return self.frame_slice(start_frame=start_frame, end_frame=end_frame) + def select_periods(self, periods): + """ + Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype. + + Parameters + ---------- + periods : numpy.array of unit_period_dtype + Period (segment_index, start_sample_index, end_sample_index, unit_index) + on which to restrict the sorting. + + Returns + ------- + BaseSorting + A new sorting object with only samples between start_sample_index and end_sample_index + for the given segment_index. + """ + from spikeinterface.core.sorting_tools import select_sorting_periods + + return select_sorting_periods(self, periods) + def split_by(self, property="group", outputs="dict"): """ Splits object based on a certain property (e.g. "group") diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 71654a67b4..f6bf3cb31f 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -22,11 +22,20 @@ ("segment_index", "int64"), ] - spike_peak_dtype = base_peak_dtype + [ ("unit_index", "int64"), ] +base_period_dtype = [ + ("start_sample_index", "int64"), + ("end_sample_index", "int64"), + ("segment_index", "int64"), +] + +unit_period_dtype = base_period_dtype + [ + ("unit_index", "int64"), +] + class PipelineNode: diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 90c7e18a99..9a9a3670ef 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -228,6 +228,83 @@ def random_spikes_selection( return random_spikes_indices +def select_sorting_periods_mask(sorting: BaseSorting, periods): + """ + Returns a boolean mask for the spikes in the sorting object, restricted to the given periods of dtype unit_period_dtype. + + Parameters + ---------- + sorting : BaseSorting + The sorting object. + periods : numpy.array of unit_period_dtype + Periods (segment_index, start_sample_index, end_sample_index, unit_index) + on which to restrict the sorting. 
diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 90c7e18a99..9a9a3670ef 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -228,6 +228,85 @@ def random_spikes_selection(
     return random_spikes_indices

+
+def select_sorting_periods_mask(sorting: BaseSorting, periods):
+    """
+    Returns a boolean mask for the spikes in the sorting object, restricted to the given periods of dtype unit_period_dtype.
+
+    Parameters
+    ----------
+    sorting : BaseSorting
+        The sorting object.
+    periods : numpy.array of unit_period_dtype
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to restrict the sorting.
+
+    Returns
+    -------
+    numpy.array
+        A boolean mask over the spikes in the sorting object, True for spikes within the specified periods.
+    """
+    spike_vector = sorting.to_spike_vector()
+    spike_vector_list = sorting.to_spike_vector(concatenated=False)
+    keep_mask = np.zeros(len(spike_vector), dtype=bool)
+    all_global_indices = spike_vector_to_indices(spike_vector_list, unit_ids=sorting.unit_ids, absolute_index=True)
+    for segment_index in range(sorting.get_num_segments()):
+        global_indices_segment = all_global_indices[segment_index]
+        # filter periods by segment
+        periods_in_segment = periods[periods["segment_index"] == segment_index]
+        for unit_index, unit_id in enumerate(sorting.unit_ids):
+            # filter by unit index
+            periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index]
+            if len(periods_for_unit) == 0:
+                continue
+            global_indices = global_indices_segment[unit_id]
+            spiketrains = spike_vector[global_indices]["sample_index"]
+            for period in periods_for_unit:
+                mask = (spiketrains >= period["start_sample_index"]) & (spiketrains < period["end_sample_index"])
+                keep_mask[global_indices[mask]] = True
+    return keep_mask
+
+
+def select_sorting_periods(sorting: BaseSorting, periods):
+    """
+    Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype.
+
+    Parameters
+    ----------
+    sorting : BaseSorting
+        The sorting object.
+    periods : numpy.array of unit_period_dtype
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to restrict the sorting.
+
+    Returns
+    -------
+    BaseSorting
+        A new sorting object containing only the spikes that fall
+        within the given periods.
+    """
+    from spikeinterface.core.numpyextractors import NumpySorting
+    from spikeinterface.core.node_pipeline import unit_period_dtype
+
+    if periods is not None:
+        if not isinstance(periods, np.ndarray):
+            periods = np.array([periods], dtype=unit_period_dtype)
+        required = set(np.dtype(unit_period_dtype).names)
+        if not required.issubset(periods.dtype.names):
+            raise ValueError(f"Period must have the following fields: {required}")
+
+        spike_vector = sorting.to_spike_vector()
+        keep_mask = select_sorting_periods_mask(sorting, periods)
+        sliced_spike_vector = spike_vector[keep_mask]
+
+        new_sorting = NumpySorting(
+            sliced_spike_vector, sampling_frequency=sorting.sampling_frequency, unit_ids=sorting.unit_ids
+        )
+        sorting.copy_metadata(new_sorting)
+        return new_sorting
+    else:
+        return sorting
+
+
 ### MERGING ZONE ###
 def apply_merges_to_sorting(
     sorting: BaseSorting,
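
For reference, the semantics of `select_sorting_periods_mask` can be restated as a per-spike predicate (a deliberately naive O(n_spikes * n_periods) version, handy as a test oracle; not part of the patch):

```python
import numpy as np

def naive_periods_mask(spike_vector, periods):
    # A spike is kept iff some period shares its segment and unit and
    # contains its sample index (start inclusive, end exclusive).
    keep = np.zeros(len(spike_vector), dtype=bool)
    for i, spike in enumerate(spike_vector):
        for period in periods:
            if (
                spike["segment_index"] == period["segment_index"]
                and spike["unit_index"] == period["unit_index"]
                and period["start_sample_index"] <= spike["sample_index"] < period["end_sample_index"]
            ):
                keep[i] = True
                break
    return keep
```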
""" -import shutil -from pathlib import Path - +import time import numpy as np import pytest from numpy.testing import assert_raises @@ -17,15 +15,15 @@ SharedMemorySorting, NpzFolderSorting, NumpyFolderSorting, + generate_ground_truth_recording, + generate_sorting, create_sorting_npz, generate_sorting, load, ) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal -from spikeinterface.core.generate import generate_sorting - -from spikeinterface.core import generate_recording, generate_ground_truth_recording +from spikeinterface.core.node_pipeline import unit_period_dtype def test_BaseSorting(create_cache_folder): @@ -226,7 +224,66 @@ def test_time_slice(): ) +def test_select_periods(): + sampling_frequency = 10_000.0 + duration = 1_000 + num_samples = int(sampling_frequency * duration) + num_units = 1000 + sorting = generate_sorting( + durations=[duration, duration], sampling_frequency=sampling_frequency, num_units=num_units + ) + + rng = np.random.default_rng() + + # number of random periods + n_periods = 10_000 + # generate random periods + segment_indices = rng.integers(0, sorting.get_num_segments(), n_periods) + start_samples = rng.integers(0, num_samples, n_periods) + durations = rng.integers(100, 100_000, n_periods) + end_samples = start_samples + durations + valid_periods = end_samples < num_samples + segment_indices = segment_indices[valid_periods] + start_samples = start_samples[valid_periods] + end_samples = end_samples[valid_periods] + unit_index = rng.integers(0, num_units - 1, len(segment_indices)) + + periods = np.zeros(len(segment_indices), dtype=unit_period_dtype) + periods["segment_index"] = segment_indices + periods["start_sample_index"] = start_samples + periods["end_sample_index"] = end_samples + periods["unit_index"] = unit_index + periods = np.sort(periods, order=["segment_index", "start_sample_index"]) + + t_start = time.perf_counter() + sliced_sorting = sorting.select_periods(periods=periods) + t_stop = time.perf_counter() + elapsed = t_stop - t_start + print(f"select_periods took {elapsed:.2f} seconds for {len(periods)} periods") + + # Check that all spikes in the sliced sorting are within the periods + for segment_index in range(sorting.get_num_segments()): + periods_in_segment = periods[periods["segment_index"] == segment_index] + for unit_index, unit_id in enumerate(sorting.unit_ids): + spiketrain = sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) + + periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index] + spiketrain_in_periods = [] + for period in periods_for_unit: + start_sample = period["start_sample_index"] + end_sample = period["end_sample_index"] + spiketrain_in_periods.append(spiketrain[(spiketrain >= start_sample) & (spiketrain < end_sample)]) + if len(spiketrain_in_periods) == 0: + spiketrain_in_periods = np.array([], dtype=spiketrain.dtype) + else: + spiketrain_in_periods = np.unique(np.concatenate(spiketrain_in_periods)) + + spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id) + assert len(spiketrain_in_periods) == len(spiketrain_sliced) + + if __name__ == "__main__": test_BaseSorting() test_npy_sorting() test_empty_sorting() + test_select_periods() diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py index c6b07da52e..4a7ef04554 100644 --- a/src/spikeinterface/metrics/quality/misc_metrics.py +++ 
diff --git a/src/spikeinterface/metrics/quality/misc_metrics.py b/src/spikeinterface/metrics/quality/misc_metrics.py
index c6b07da52e..4a7ef04554 100644
--- a/src/spikeinterface/metrics/quality/misc_metrics.py
+++ b/src/spikeinterface/metrics/quality/misc_metrics.py
@@ -19,12 +19,13 @@
 from spikeinterface.core.analyzer_extension_core import BaseMetric
 from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
 from spikeinterface.postprocessing import correlogram_for_one_segment
-from spikeinterface.core import SortingAnalyzer, get_noise_levels
+from spikeinterface.core import SortingAnalyzer, get_noise_levels, select_segment_sorting
 from spikeinterface.core.template_tools import (
     get_template_extremum_channel,
     get_template_extremum_amplitude,
     get_dense_templates_array,
 )
+from spikeinterface.core.node_pipeline import base_period_dtype
 from ..spiketrain.metrics import NumSpikes, FiringRate

@@ -35,7 +36,9 @@
     HAVE_NUMBA = False


-def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0):
+def compute_presence_ratios(
+    sorting_analyzer, unit_ids=None, bin_duration_s=60.0, mean_fr_ratio_thresh=0.0, periods=None
+):
     """
     Calculate the presence ratio, the fraction of time the unit is firing above a certain threshold.
@@ -51,6 +54,9 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0
     mean_fr_ratio_thresh : float, default: 0
         The unit is considered active in a bin if its firing rate during that bin.
         is strictly above `mean_fr_ratio_thresh` times its mean firing rate throughout the recording.
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.

     Returns
     -------
@@ -63,6 +69,7 @@ def compute_presence_ratios(sorting_analyzer, unit_ids=None, bin_duration_s=60.0
     To do so, spike trains across segments are concatenated to mimic a continuous segment.
     """
     sorting = sorting_analyzer.sorting
+    sorting = sorting.select_periods(periods=periods)
     if unit_ids is None:
         unit_ids = sorting_analyzer.unit_ids
     num_segs = sorting_analyzer.get_num_segments()
@@ -182,7 +189,7 @@ class SNR(BaseMetric):
     depend_on = ["noise_levels", "templates"]


-def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0):
+def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5, min_isi_ms=0, periods=None):
     """
     Calculate Inter-Spike Interval (ISI) violations.
@@ -204,6 +211,9 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5
     min_isi_ms : float, default: 0
         Minimum possible inter-spike interval, in ms. This is the artificial refractory period enforced.
         by the data acquisition system or post-processing algorithms.
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.

     Returns
     -------
@@ -235,6 +245,7 @@ def compute_isi_violations(sorting_analyzer, unit_ids=None, isi_threshold_ms=1.5
     res = namedtuple("isi_violation", ["isi_violations_ratio", "isi_violations_count"])

     sorting = sorting_analyzer.sorting
+    sorting = sorting.select_periods(periods=periods)
     if unit_ids is None:
         unit_ids = sorting_analyzer.unit_ids
     num_segs = sorting_analyzer.get_num_segments()
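
End-to-end sketch of restricting a quality metric with `periods` (assumes this patch is applied; values are illustrative):

```python
import numpy as np
from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
from spikeinterface.core.node_pipeline import unit_period_dtype
from spikeinterface.metrics.quality.misc_metrics import compute_isi_violations

recording, sorting = generate_ground_truth_recording(durations=[60.0], num_units=5)
analyzer = create_sorting_analyzer(sorting, recording)

# Restrict the metric to the first 30 s of segment 0 for every unit.
fs = sorting.sampling_frequency
periods = np.zeros(len(sorting.unit_ids), dtype=unit_period_dtype)
periods["segment_index"] = 0
periods["start_sample_index"] = 0
periods["end_sample_index"] = int(30.0 * fs)
periods["unit_index"] = np.arange(len(sorting.unit_ids))

res = compute_isi_violations(analyzer, periods=periods)
print(res.isi_violations_ratio)
```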
@@ -280,7 +291,7 @@


 def compute_refrac_period_violations(
-    sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0
+    sorting_analyzer, unit_ids=None, refractory_period_ms: float = 1.0, censored_period_ms: float = 0.0, periods=None
 ):
     """
     Calculate the number of refractory period violations.
@@ -300,6 +311,9 @@ def compute_refrac_period_violations(
     censored_period_ms : float, default: 0.0
         The period (in ms) where no 2 spikes can occur (because they are not detected, or because they were
         removed by another mean).
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.

     Returns
     -------
@@ -332,6 +346,8 @@ def compute_refrac_period_violations(
         return None

     sorting = sorting_analyzer.sorting
+    sorting = sorting.select_periods(periods=periods)
+
     fs = sorting_analyzer.sampling_frequency
     num_units = len(sorting_analyzer.unit_ids)
     num_segments = sorting_analyzer.get_num_segments()
@@ -392,6 +408,7 @@ def compute_sliding_rp_violations(
     exclude_ref_period_below_ms=0.5,
     max_ref_period_ms=10,
     contamination_values=None,
+    periods=None,
 ):
     """
     Compute sliding refractory period violations, a metric developed by IBL which computes
@@ -417,6 +434,9 @@ def compute_sliding_rp_violations(
     max_ref_period_ms : float, default: 10
         Maximum refractory period to test in ms.
     contamination_values : 1d array or None, default: None
         The contamination values to test, If None, it is set to np.arange(0.5, 35, 0.5).
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.

     Returns
     -------
@@ -431,6 +451,8 @@ def compute_sliding_rp_violations(
     """
     duration = sorting_analyzer.get_total_duration()
     sorting = sorting_analyzer.sorting
+    sorting = sorting.select_periods(periods=periods)
+
     if unit_ids is None:
         unit_ids = sorting_analyzer.unit_ids
     num_segs = sorting_analyzer.get_num_segments()
@@ -486,7 +508,7 @@ class SlidingRPViolation(BaseMetric):
     }


-def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None):
+def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=None, periods=None):
     """
     Compute synchrony metrics. Synchrony metrics represent the rate of occurrences of spikes at the exact
     same sample index, with synchrony sizes 2, 4 and 8.
@@ -504,6 +526,9 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.

     Returns
     -------
     sync_spike_{X} : dict
         The synchrony metric for synchrony size X.

@@ -520,6 +545,7 @@ def compute_synchrony_metrics(sorting_analyzer, unit_ids=None, synchrony_sizes=N
     res = namedtuple("synchrony_metrics", [f"sync_spike_{size}" for size in synchrony_sizes])

     sorting = sorting_analyzer.sorting
+    sorting = sorting.select_periods(periods=periods)

     if unit_ids is None:
         unit_ids = sorting.unit_ids
@@ -556,7 +582,7 @@ class Synchrony(BaseMetric):
     }


-def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95)):
+def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percentiles=(5, 95), periods=None):
     """
     Calculate firing range, the range between the 5th and 95th percentiles of the firing rates distribution
     computed in non-overlapping time bins.
@@ -571,6 +597,9 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent
         The size of the bin in seconds.
     percentiles : tuple, default: (5, 95)
         The percentiles to compute.
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.

     Returns
     -------
@@ -584,6 +613,8 @@ def compute_firing_ranges(sorting_analyzer, unit_ids=None, bin_size_s=5, percent
     sampling_frequency = sorting_analyzer.sampling_frequency
     bin_size_samples = int(bin_size_s * sampling_frequency)
     sorting = sorting_analyzer.sorting
+    sorting = sorting.select_periods(periods=periods)
+
     if unit_ids is None:
         unit_ids = sorting.unit_ids
@@ -635,6 +666,7 @@ def compute_amplitude_cv_metrics(
     percentiles=(5, 95),
     min_num_bins=10,
     amplitude_extension="spike_amplitudes",
+    periods=None,
 ):
     """
     Calculate coefficient of variation of spike amplitudes within defined temporal bins.
@@ -658,6 +690,9 @@ def compute_amplitude_cv_metrics(
         the median and range are set to NaN.
     amplitude_extension : str, default: "spike_amplitudes"
         The name of the extension to load the amplitudes from. "spike_amplitudes" or "amplitude_scalings".
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.

     Returns
     -------
@@ -683,7 +717,7 @@ def compute_amplitude_cv_metrics(
     if unit_ids is None:
         unit_ids = sorting.unit_ids

-    amps = sorting_analyzer.get_extension(amplitude_extension).get_data()
+    amps = sorting_analyzer.get_extension(amplitude_extension).get_data(periods=periods)

     # precompute segment slice
     segment_slices = []
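
The amplitude-based metrics below all funnel through `get_data(..., periods=...)` on the extension; a sketch of that call, continuing from the previous example (assumes the extensions have been computed):

```python
# Dependencies of "spike_amplitudes" in the standard extension chain.
analyzer.compute("random_spikes")
analyzer.compute("templates")
analyzer.compute("spike_amplitudes")

ext = analyzer.get_extension("spike_amplitudes")

all_amps = ext.get_data()                        # one value per spike, concatenated
amps_in_periods = ext.get_data(periods=periods)  # restricted to spikes inside `periods`
assert len(amps_in_periods) <= len(all_amps)

# Same restriction, but as {unit_id: amplitudes} with segments concatenated,
# which is the form consumed by compute_amplitude_cutoffs and friends.
amps_by_unit = ext.get_data(outputs="by_unit", concatenated=True, periods=periods)
```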
@@ -752,6 +786,7 @@ def compute_amplitude_cutoffs(
     num_histogram_bins=500,
     histogram_smoothing_value=3,
     amplitudes_bins_min_ratio=5,
+    periods=None,
 ):
     """
     Calculate approximate fraction of spikes missing from a distribution of amplitudes.
@@ -770,6 +805,9 @@ def compute_amplitude_cutoffs(
     amplitudes_bins_min_ratio : int, default: 5
         The minimum ratio between number of amplitudes for a unit and the number of bins.
         If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN.
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.

     Returns
     -------
@@ -805,7 +843,7 @@ def compute_amplitude_cutoffs(
         invert_amplitudes = True
         extension = sorting_analyzer.get_extension("amplitude_scalings")

-    amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True)
+    amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods)

     for unit_id in unit_ids:
         amplitudes = amplitudes_by_units[unit_id]
@@ -837,7 +875,7 @@ class AmplitudeCutoff(BaseMetric):
     depend_on = ["spike_amplitudes|amplitude_scalings"]


-def compute_amplitude_medians(sorting_analyzer, unit_ids=None):
+def compute_amplitude_medians(sorting_analyzer, unit_ids=None, periods=None):
     """
     Compute median of the amplitude distributions (in absolute value).
@@ -847,6 +885,9 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None):
         A SortingAnalyzer object.
     unit_ids : list or None
         List of unit ids to compute the amplitude medians. If None, all units are used.
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.

     Returns
     -------
@@ -865,7 +906,7 @@ def compute_amplitude_medians(sorting_analyzer, unit_ids=None):
     all_amplitude_medians = {}
     amplitude_extension = sorting_analyzer.get_extension("spike_amplitudes")

-    amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True)
+    amplitudes_by_units = amplitude_extension.get_data(outputs="by_unit", concatenated=True, periods=periods)

     for unit_id in unit_ids:
         all_amplitude_medians[unit_id] = np.median(amplitudes_by_units[unit_id])
@@ -882,7 +923,9 @@ class AmplitudeMedian(BaseMetric):
     depend_on = ["spike_amplitudes"]


-def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100):
+def compute_noise_cutoffs(
+    sorting_analyzer, unit_ids=None, high_quantile=0.25, low_quantile=0.1, n_bins=100, periods=None
+):
     """
     A metric to determine if a unit's amplitude distribution is cut off as it approaches zero,
     without assuming a Gaussian distribution.
@@ -906,6 +949,9 @@ def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, l
         Quantile of the amplitude range below which values are treated as "low" (e.g. 0.1 = lower 10%), the test region.
     n_bins: int, default: 100
         The number of bins to use to compute the amplitude histogram.
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.

     Returns
     -------
@@ -934,7 +980,7 @@ def compute_noise_cutoffs(sorting_analyzer, unit_ids=None, high_quantile=0.25, l
         invert_amplitudes = True
         extension = sorting_analyzer.get_extension("amplitude_scalings")

-    amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True)
+    amplitudes_by_units = extension.get_data(outputs="by_unit", concatenated=True, periods=periods)

     for unit_id in unit_ids:
         amplitudes = amplitudes_by_units[unit_id]
@@ -972,6 +1018,7 @@ def compute_drift_metrics(
     min_fraction_valid_intervals=0.5,
     min_num_bins=2,
     return_positions=False,
+    periods=None,
 ):
     """
     Compute drifts metrics using estimated spike locations.
@@ -1006,6 +1053,9 @@ def compute_drift_metrics(
     min_num_bins : int, default: 2
         Minimum number of bins required to return a valid metric value. In case there are less bins,
         the metric values are set to NaN.
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.
     return_positions : bool, default: False
         If True, median positions are returned (for debugging).
@@ -1032,8 +1082,8 @@ def compute_drift_metrics(
     unit_ids = sorting.unit_ids

     spike_locations_ext = sorting_analyzer.get_extension("spike_locations")
-    spike_locations = spike_locations_ext.get_data()
-    # spike_locations_by_unit = spike_locations_ext.get_data(outputs="by_unit")
+    spike_locations = spike_locations_ext.get_data(periods=periods)
+    sorting = sorting.select_periods(periods=periods)
     spikes = sorting.to_spike_vector()
     spike_locations_by_unit = {}
     for unit_id in unit_ids:
@@ -1145,12 +1194,13 @@ class Drift(BaseMetric):
     depend_on = ["spike_locations"]


 def compute_sd_ratio(
     sorting_analyzer: SortingAnalyzer,
     unit_ids=None,
     censored_period_ms: float = 4.0,
     correct_for_drift: bool = True,
     correct_for_template_itself: bool = True,
+    periods=None,
     **kwargs,
 ):
     """
@@ -1173,6 +1224,9 @@ def compute_sd_ratio(
     correct_for_template_itself : bool, default: True
         If true, will take into account that the template itself impacts the standard deviation of the noise,
         and will make a rough estimation of what that impact is (and remove it).
+    periods : array of unit_period_dtype | None, default: None
+        Periods (segment_index, start_sample_index, end_sample_index, unit_index)
+        on which to compute the metric. If None, the entire recording duration is used.
     **kwargs : dict, default: {}
         Keyword arguments for computing spike amplitudes and extremum channel.
@@ -1189,6 +1243,7 @@ def compute_sd_ratio(
     job_kwargs = fix_job_kwargs(job_kwargs)

     sorting = sorting_analyzer.sorting
+    sorting = sorting.select_periods(periods=periods)
     censored_period = int(round(censored_period_ms * 1e-3 * sorting_analyzer.sampling_frequency))

     if unit_ids is None:
@@ -1201,7 +1256,7 @@ def compute_sd_ratio(
         )
         return {unit_id: np.nan for unit_id in unit_ids}

-    spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data()
+    spike_amplitudes = sorting_analyzer.get_extension("spike_amplitudes").get_data(periods=periods)

     if not HAVE_NUMBA:
         warnings.warn(
diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
index ab7ae9e7b5..75e41620f4 100644
--- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
+++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py
@@ -130,7 +130,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan
     recording._properties["contact_vector"][idx][1] = x[idx]

     # generate random bad channel locations
-    bad_channel_indexes = rng.choice(num_channels, rng.randint(1, int(num_channels / 5)), replace=False)
+    bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False)
     bad_channel_ids = recording.channel_ids[bad_channel_indexes]

     # Run SI and IBL interpolation and check against eachother
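
A small helper that callers may find handy on top of this API (hypothetical, not part of the patch): building a `unit_period_dtype` array that restricts every unit to the same time window.

```python
import numpy as np
from spikeinterface.core.node_pipeline import unit_period_dtype

def periods_from_time_window(sorting, t0, t1, segment_index=0):
    """Restrict every unit of `sorting` to the window [t0, t1), given in seconds."""
    fs = sorting.sampling_frequency
    num_units = len(sorting.unit_ids)
    periods = np.zeros(num_units, dtype=unit_period_dtype)
    periods["segment_index"] = segment_index
    periods["start_sample_index"] = int(t0 * fs)
    periods["end_sample_index"] = int(t1 * fs)
    periods["unit_index"] = np.arange(num_units)
    return periods

# e.g. quality metrics over minutes 1-2 only:
# periods = periods_from_time_window(sorting, 60.0, 120.0)
# compute_presence_ratios(sorting_analyzer, periods=periods)
```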