Merged. Changes from all commits (22 commits).
29 changes: 24 additions & 5 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -17,7 +17,7 @@
from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
from .recording_tools import get_noise_levels
from .template import Templates
-from .sorting_tools import random_spikes_selection
+from .sorting_tools import random_spikes_selection, select_sorting_periods_mask, spike_vector_to_indices
from .job_tools import fix_job_kwargs, split_job_kwargs


@@ -1343,6 +1343,9 @@ class BaseSpikeVectorExtension(AnalyzerExtension):
need_backward_compatibility_on_load = False
nodepipeline_variables = [] # to be defined in subclass

def __init__(self, sorting_analyzer):
super().__init__(sorting_analyzer)

def _set_params(self, **kwargs):
params = kwargs.copy()
return params
@@ -1381,7 +1384,7 @@ def _run(self, verbose=False, **job_kwargs):
for d, name in zip(data, data_names):
self.data[name] = d

-    def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, copy=True):
+    def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, periods=None, copy=True):
"""
Return extension data. If the extension computes more than one `nodepipeline_variables`,
the `return_data_name` is used to specify which one to return.
@@ -1395,13 +1398,15 @@
return_data_name : str | None, default: None
The name of the data to return. If None and multiple `nodepipeline_variables` are computed,
the first one is returned.
periods : array of unit_period dtype, default: None
Optional periods (segment_index, start_sample_index, end_sample_index, unit_index) used to slice the output data.
copy : bool, default: True
Whether to return a copy of the data (only for outputs="numpy")

Returns
-------
numpy.ndarray | dict
-            The
+            The requested data in numpy or by unit format.
"""

if len(self.nodepipeline_variables) == 1:
@@ -1415,15 +1420,29 @@
), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"

all_data = self.data[return_data_name]
keep_mask = None
if periods is not None:
keep_mask = select_sorting_periods_mask(
self.sorting_analyzer.sorting,
periods,
)
all_data = all_data[keep_mask]

if outputs == "numpy":
if copy:
return all_data.copy() # return a copy to avoid modification
else:
return all_data
elif outputs == "by_unit":
unit_ids = self.sorting_analyzer.unit_ids
-            # use the cache of indices
-            spike_indices = self.sorting_analyzer.sorting.get_spike_vector_to_indices()
+            spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
+            if keep_mask is not None:
+                # since we are filtering spikes, we need to recompute the spike indices
+                spike_vector = spike_vector[keep_mask]
+                spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
+            else:
+                # use the cache of indices
+                spike_indices = self.sorting_analyzer.sorting.get_spike_vector_to_indices()
data_by_units = {}
for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
data_by_units[segment_index] = {}
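To show how the new periods argument flows through _get_data, here is a minimal usage sketch. It assumes spike_amplitudes is one of the BaseSpikeVectorExtension subclasses (that relationship is not shown in this diff) and goes through the public get_data, which forwards keyword arguments to _get_data:

    import numpy as np
    from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
    from spikeinterface.core.base import unit_period_dtype

    recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5)
    analyzer = create_sorting_analyzer(sorting, recording)
    analyzer.compute(["random_spikes", "templates", "spike_amplitudes"])

    # restrict the returned amplitudes to unit 0 during the first second of segment 0
    periods = np.zeros(1, dtype=unit_period_dtype)
    periods["segment_index"] = 0
    periods["start_sample_index"] = 0
    periods["end_sample_index"] = int(recording.sampling_frequency)
    periods["unit_index"] = 0

    amps = analyzer.get_extension("spike_amplitudes").get_data(periods=periods)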
2 changes: 1 addition & 1 deletion src/spikeinterface/core/base.py
@@ -42,9 +42,9 @@
minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]

base_period_dtype = [
+    ("segment_index", "int64"),
    ("start_sample_index", "int64"),
    ("end_sample_index", "int64"),
-    ("segment_index", "int64"),
]

unit_period_dtype = base_period_dtype + [
20 changes: 20 additions & 0 deletions src/spikeinterface/core/basesorting.py
@@ -639,6 +639,26 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSorting:

return self.frame_slice(start_frame=start_frame, end_frame=end_frame)

def select_periods(self, periods):
"""
Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype.

Parameters
----------
periods : numpy.array of unit_period_dtype
Periods (segment_index, start_sample_index, end_sample_index, unit_index)
on which to restrict the sorting.

Returns
-------
BaseSorting
A new sorting object with only samples between start_sample_index and end_sample_index
for the given segment_index.
"""
from spikeinterface.core.sorting_tools import select_sorting_periods

return select_sorting_periods(self, periods)

def split_by(self, property="group", outputs="dict"):
"""
Splits object based on a certain property (e.g. "group")
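A short usage sketch for select_periods, exercising the plain (num_periods, 4) array form that select_sorting_periods (in sorting_tools.py below) converts with a UserWarning. The column order, (segment_index, start_sample_index, end_sample_index, unit_index), is the one documented there:

    import numpy as np
    from spikeinterface.core import generate_sorting

    sorting = generate_sorting(durations=[10.0], sampling_frequency=30_000.0, num_units=2)

    # columns: segment_index, start_sample_index, end_sample_index, unit_index
    periods = np.array(
        [
            [0, 0, 30_000, 0],
            [0, 60_000, 90_000, 1],
        ],
        dtype="int64",
    )

    # keeps unit 0's spikes in the first second and unit 1's spikes between seconds 2 and 3
    sliced = sorting.select_periods(periods)  # emits a UserWarning about the assumed column order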
96 changes: 96 additions & 0 deletions src/spikeinterface/core/sorting_tools.py
@@ -228,6 +228,102 @@ def random_spikes_selection(
return random_spikes_indices


def select_sorting_periods_mask(sorting: BaseSorting, periods):
"""
Returns a boolean mask for the spikes in the sorting object, restricted to the given periods of dtype unit_period_dtype.

Parameters
----------
sorting : BaseSorting
The sorting object.
periods : numpy.array of unit_period_dtype
Periods (segment_index, start_sample_index, end_sample_index, unit_index)
on which to restrict the sorting.

Returns
-------
numpy.array
A boolean mask of the spikes in the sorting object, with True for spikes within the specified periods.
"""
spike_vector = sorting.to_spike_vector()
keep_mask = np.zeros(len(spike_vector), dtype=bool)
all_global_indices = sorting.get_spike_vector_to_indices()
for segment_index in range(sorting.get_num_segments()):
global_indices_segment = all_global_indices[segment_index]
# filter periods by segment
periods_in_segment = periods[periods["segment_index"] == segment_index]
for unit_index, unit_id in enumerate(sorting.unit_ids):
# filter by unit index
periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index]
global_indices = global_indices_segment[unit_id]
spiketrains = spike_vector[global_indices]["sample_index"]
if len(periods_for_unit) > 0:
for period in periods_for_unit:
mask = (spiketrains >= period["start_sample_index"]) & (spiketrains < period["end_sample_index"])
keep_mask[global_indices[mask]] = True
return keep_mask


def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting:
"""
Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype.

Parameters
----------
sorting : BaseSorting
    The sorting object to restrict.
periods : numpy.ndarray
Periods (segment_index, start_sample_index, end_sample_index, unit_index)
on which to restrict the sorting. Periods can be either a numpy array of unit_period_dtype
or an array with (num_periods, 4) shape. In the latter case, the fields are assumed to be
in the order: segment_index, start_sample_index, end_sample_index, unit_index.

Returns
-------
BaseSorting
A new sorting object with only samples between start_sample_index and end_sample_index
for the given segment_index.
"""
from spikeinterface.core.base import unit_period_dtype
from spikeinterface.core.numpyextractors import NumpySorting

if periods is not None:
if not isinstance(periods, np.ndarray):
raise ValueError("periods must be a numpy array")
if not periods.dtype == unit_period_dtype:
if periods.ndim != 2 or periods.shape[1] != 4:
raise ValueError(
"If periods is not of dtype unit_period_dtype, it must be a 2D array with shape (num_periods, 4)"
)
warnings.warn(
"periods is not of dtype unit_period_dtype. Assuming fields are in order: "
"(segment_index, start_sample_index, end_sample_index, unit_index).",
UserWarning,
)
# convert to structured array
periods_converted = np.empty(periods.shape[0], dtype=unit_period_dtype)
periods_converted["segment_index"] = periods[:, 0]
periods_converted["start_sample_index"] = periods[:, 1]
periods_converted["end_sample_index"] = periods[:, 2]
periods_converted["unit_index"] = periods[:, 3]
periods = periods_converted

required = set(np.dtype(unit_period_dtype).names)
if not required.issubset(periods.dtype.names):
raise ValueError(f"Period must have the following fields: {required}")

spike_vector = sorting.to_spike_vector()
keep_mask = select_sorting_periods_mask(sorting, periods)
sliced_spike_vector = spike_vector[keep_mask]

# important: we keep the original unit ids so the unit_index field in spike vector is still valid
        sliced_sorting = NumpySorting(
            sliced_spike_vector, sampling_frequency=sorting.sampling_frequency, unit_ids=sorting.unit_ids
        )
        # copy annotations/properties from the original sorting onto the sliced one
        sorting.copy_metadata(sliced_sorting)
        return sliced_sorting
else:
return sorting


### MERGING ZONE ###
def apply_merges_to_sorting(
sorting: BaseSorting,
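Taken together, a minimal sketch of the two new helpers on a generated sorting (function and dtype names exactly as added above):

    import numpy as np
    from spikeinterface.core import generate_sorting
    from spikeinterface.core.base import unit_period_dtype
    from spikeinterface.core.sorting_tools import select_sorting_periods, select_sorting_periods_mask

    sorting = generate_sorting(durations=[10.0], sampling_frequency=30_000.0, num_units=3)

    # keep unit 1's spikes between samples 0 and 60_000 of segment 0
    periods = np.zeros(1, dtype=unit_period_dtype)
    periods["segment_index"] = 0
    periods["start_sample_index"] = 0
    periods["end_sample_index"] = 60_000
    periods["unit_index"] = 1

    mask = select_sorting_periods_mask(sorting, periods)  # boolean mask over the concatenated spike vector
    sliced = select_sorting_periods(sorting, periods)     # new NumpySorting restricted to the periods
    assert mask.sum() == len(sliced.to_spike_vector())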
80 changes: 73 additions & 7 deletions src/spikeinterface/core/tests/test_basesorting.py
@@ -3,9 +3,7 @@
but check only for BaseRecording general methods.
"""

-import shutil
-from pathlib import Path
-
+import time
import numpy as np
import pytest
from numpy.testing import assert_raises
@@ -17,15 +15,14 @@
SharedMemorySorting,
NpzFolderSorting,
NumpyFolderSorting,
+    generate_ground_truth_recording,
> Review comment (Member): why do we need this in testing sorting ?
-    generate_sorting,
create_sorting_npz,
+    generate_sorting,
load,
)
-from spikeinterface.core.base import BaseExtractor
+from spikeinterface.core.base import BaseExtractor, unit_period_dtype
from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal
-from spikeinterface.core.generate import generate_sorting

+from spikeinterface.core import generate_recording, generate_ground_truth_recording


def test_BaseSorting(create_cache_folder):
@@ -245,6 +242,74 @@ def test_time_slice():
)


def test_select_periods():
sampling_frequency = 10_000.0
duration = 100
num_samples = int(sampling_frequency * duration)
num_units = 1000
sorting = generate_sorting(
durations=[duration, duration], sampling_frequency=sampling_frequency, num_units=num_units
)

rng = np.random.default_rng()

# number of random periods
n_periods = 1_000
# generate random periods
segment_indices = rng.integers(0, sorting.get_num_segments(), n_periods)
start_samples = rng.integers(0, num_samples, n_periods)
durations = rng.integers(100, 100_000, n_periods)
end_samples = start_samples + durations
valid_periods = end_samples < num_samples
segment_indices = segment_indices[valid_periods]
start_samples = start_samples[valid_periods]
end_samples = end_samples[valid_periods]
unit_index = rng.integers(0, num_units - 1, len(segment_indices))

periods = np.zeros(len(segment_indices), dtype=unit_period_dtype)
periods["segment_index"] = segment_indices
periods["start_sample_index"] = start_samples
periods["end_sample_index"] = end_samples
periods["unit_index"] = unit_index
periods = np.sort(periods, order=["segment_index", "start_sample_index"])

t_start = time.perf_counter()
sliced_sorting = sorting.select_periods(periods=periods)
t_stop = time.perf_counter()
elapsed = t_stop - t_start
print(f"select_periods took {elapsed:.2f} seconds for {len(periods)} periods")

# Check that all spikes in the sliced sorting are within the periods
for segment_index in range(sorting.get_num_segments()):
periods_in_segment = periods[periods["segment_index"] == segment_index]
for unit_index, unit_id in enumerate(sorting.unit_ids):
spiketrain = sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id)

periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index]
spiketrain_in_periods = []
for period in periods_for_unit:
start_sample = period["start_sample_index"]
end_sample = period["end_sample_index"]
spiketrain_in_periods.append(spiketrain[(spiketrain >= start_sample) & (spiketrain < end_sample)])
if len(spiketrain_in_periods) == 0:
spiketrain_in_periods = np.array([], dtype=spiketrain.dtype)
else:
spiketrain_in_periods = np.unique(np.concatenate(spiketrain_in_periods))

spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id)
assert len(spiketrain_in_periods) == len(spiketrain_sliced)

# now test with input as numpy array with shape (n_periods, 4)
periods_array = np.zeros((len(periods), 4), dtype="int64")
periods_array[:, 0] = periods["segment_index"]
periods_array[:, 1] = periods["start_sample_index"]
periods_array[:, 2] = periods["end_sample_index"]
periods_array[:, 3] = periods["unit_index"]

sliced_sorting_array = sorting.select_periods(periods=periods_array)
np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector())


if __name__ == "__main__":
import tempfile

@@ -254,3 +319,4 @@ def test_time_slice():
test_BaseSorting(cache_folder)
test_npy_sorting()
test_empty_sorting()
test_select_periods()
@@ -130,7 +130,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan
recording._properties["contact_vector"][idx][1] = x[idx]

# generate random bad channel locations
-    bad_channel_indexes = rng.choice(num_channels, rng.randint(1, int(num_channels / 5)), replace=False)
+    bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False)
bad_channel_ids = recording.channel_ids[bad_channel_indexes]

# Run SI and IBL interpolation and check against eachother
Expand Down