Commits
32 commits
3dc5729
Test IBL extractors tests failing for PI update
alejoe91 Dec 29, 2025
d1a0532
Merge branch 'main' of github.com:SpikeInterface/spikeinterface
alejoe91 Jan 6, 2026
7279b67
wip
alejoe91 Jan 7, 2026
1962f21
Fix test for base sorting and propagate to basevector extension
alejoe91 Jan 7, 2026
528c82b
Fix tests in quailty metrics
alejoe91 Jan 8, 2026
775dda7
Fix retrieval of spikevector features
alejoe91 Jan 8, 2026
bb46f27
Update src/spikeinterface/core/sorting_tools.py
alejoe91 Jan 13, 2026
121a0b1
Apply suggestion from @chrishalcrow
alejoe91 Jan 13, 2026
cbf3213
refactor presence ratio and drift metrics to use periods properly
alejoe91 Jan 13, 2026
4409aa5
Fix rp_violations
alejoe91 Jan 13, 2026
71f8668
implement firing range and fix drift
alejoe91 Jan 13, 2026
1ea0d68
fix naming issue
alejoe91 Jan 13, 2026
a86c2d3
remove solved todos
alejoe91 Jan 13, 2026
d8e1f90
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jan 13, 2026
3f93f97
Implement select_segment_periods in core
alejoe91 Jan 13, 2026
cd85456
remove utils
alejoe91 Jan 13, 2026
7a42fe3
rebase on #4316
alejoe91 Jan 13, 2026
4f754cb
Merge with main
alejoe91 Jan 14, 2026
cbc0986
Fix import
alejoe91 Jan 14, 2026
56b672e
Merge branch 'select_sorting_periods_core' into select_sorting_periods
alejoe91 Jan 14, 2026
046430e
fix import
alejoe91 Jan 14, 2026
bb86253
Add misc_metric changes
alejoe91 Jan 14, 2026
50f33f0
fix tests
alejoe91 Jan 14, 2026
80bc50f
Change base_period_dtype order and fix select_sorting_periods array i…
alejoe91 Jan 15, 2026
4c8fa23
fix conflicts
alejoe91 Jan 15, 2026
e1f5bab
Merge metrics implementations
alejoe91 Jan 15, 2026
96e6a53
fix tests
alejoe91 Jan 15, 2026
3198911
Fix generation of bins
alejoe91 Jan 15, 2026
87fbe9a
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jan 16, 2026
7446a43
Use cached get_spike_vector_to_indices
alejoe91 Jan 16, 2026
873a687
Solve conflicts
alejoe91 Jan 16, 2026
51e906a
Fix error in merging
alejoe91 Jan 16, 2026
33 changes: 30 additions & 3 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -17,7 +17,7 @@
from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
from .recording_tools import get_noise_levels
from .template import Templates
from .sorting_tools import random_spikes_selection
from .sorting_tools import random_spikes_selection, select_sorting_periods_mask
from .job_tools import fix_job_kwargs, split_job_kwargs


@@ -1331,6 +1331,21 @@ class BaseSpikeVectorExtension(AnalyzerExtension):
need_backward_compatibility_on_load = False
nodepipeline_variables = [] # to be defined in subclass

def __init__(self, sorting_analyzer):
super().__init__(sorting_analyzer)
self._segment_slices = None

@property
def segment_slices(self):
if self._segment_slices is None:
segment_slices = []
spikes = self.sorting_analyzer.sorting.to_spike_vector()
for segment_index in range(self.sorting_analyzer.get_num_segments()):
i0, i1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1])
segment_slices.append(slice(i0, i1))
self._segment_slices = segment_slices
return self._segment_slices

def _set_params(self, **kwargs):
params = kwargs.copy()
return params
@@ -1369,7 +1384,7 @@ def _run(self, verbose=False, **job_kwargs):
for d, name in zip(data, data_names):
self.data[name] = d

def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, copy=True):
def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, periods=None, copy=True):
"""
Return extension data. If the extension computes more than one `nodepipeline_variables`,
the `return_data_name` is used to specify which one to return.
@@ -1383,13 +1398,15 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
return_data_name : str | None, default: None
The name of the data to return. If None and multiple `nodepipeline_variables` are computed,
the first one is returned.
periods : numpy.array of unit_period_dtype, default: None
Optional periods (segment_index, start_sample_index, end_sample_index, unit_index) used to slice the output data
copy : bool, default: True
Whether to return a copy of the data (only for outputs="numpy")

Returns
-------
numpy.ndarray | dict
The
The requested data in numpy or by unit format.
"""
from spikeinterface.core.sorting_tools import spike_vector_to_indices

@@ -1404,6 +1421,14 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"

all_data = self.data[return_data_name]
keep_mask = None
if periods is not None:
keep_mask = select_sorting_periods_mask(
self.sorting_analyzer.sorting,
periods,
)
all_data = all_data[keep_mask]

if outputs == "numpy":
if copy:
return all_data.copy() # return a copy to avoid modification
@@ -1412,6 +1437,8 @@ def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None,
elif outputs == "by_unit":
unit_ids = self.sorting_analyzer.unit_ids
spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
if keep_mask is not None:
spike_vector = spike_vector[keep_mask]
spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
data_by_units = {}
for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
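For context (not part of the diff), a minimal sketch of how the new periods argument could be used from the public API. The extension name "spike_amplitudes" and the assumption that get_data() forwards keyword arguments to _get_data() come from the existing analyzer API, not from this PR:

import numpy as np
from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer
from spikeinterface.core.node_pipeline import unit_period_dtype

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
analyzer = create_sorting_analyzer(sorting, recording)
analyzer.compute(["random_spikes", "templates", "spike_amplitudes"])

# one period: unit 0, segment 0, samples [0, 50_000)
periods = np.zeros(1, dtype=unit_period_dtype)
periods["segment_index"] = 0
periods["start_sample_index"] = 0
periods["end_sample_index"] = 50_000
periods["unit_index"] = 0

# only amplitudes of spikes falling inside the periods are returned
amps = analyzer.get_extension("spike_amplitudes").get_data(periods=periods)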
20 changes: 20 additions & 0 deletions src/spikeinterface/core/basesorting.py
@@ -626,6 +626,26 @@ def time_slice(self, start_time: float | None, end_time: float | None) -> BaseSo

return self.frame_slice(start_frame=start_frame, end_frame=end_frame)

def select_periods(self, periods):
"""
Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype.

Parameters
----------
periods : numpy.array of unit_period_dtype
Periods (segment_index, start_sample_index, end_sample_index, unit_index)
on which to restrict the sorting.

Returns
-------
BaseSorting
A new sorting object containing only the spikes that fall between start_sample_index and end_sample_index
for the corresponding segment_index and unit_index.
"""
from spikeinterface.core.sorting_tools import select_sorting_periods

return select_sorting_periods(self, periods)

def split_by(self, property="group", outputs="dict"):
"""
Splits object based on a certain property (e.g. "group")
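An illustrative sketch of the new BaseSorting.select_periods method (not taken from the PR itself; the generate_sorting parameters are arbitrary):

import numpy as np
from spikeinterface.core import generate_sorting
from spikeinterface.core.node_pipeline import unit_period_dtype

sorting = generate_sorting(durations=[10.0, 10.0], sampling_frequency=30000.0, num_units=4, seed=0)

# keep unit 1 only in samples [0, 30_000) of segment 0,
# and unit 2 only in samples [60_000, 90_000) of segment 1
periods = np.zeros(2, dtype=unit_period_dtype)
periods["segment_index"] = [0, 1]
periods["start_sample_index"] = [0, 60_000]
periods["end_sample_index"] = [30_000, 90_000]
periods["unit_index"] = [1, 2]

sliced = sorting.select_periods(periods)
# units not listed in `periods` keep their unit_id but end up with empty spike trains
print(sliced.get_unit_spike_train(unit_id=sorting.unit_ids[0], segment_index=0))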
11 changes: 10 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
@@ -22,11 +22,20 @@
("segment_index", "int64"),
]


spike_peak_dtype = base_peak_dtype + [
("unit_index", "int64"),
]

base_period_dtype = [
Member Author: move all dtypes to base

Member Author: done in #4314

("start_sample_index", "int64"),
("end_sample_index", "int64"),
("segment_index", "int64"),
]

unit_period_dtype = base_period_dtype + [
("unit_index", "int64"),
]


class PipelineNode:

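A quick illustrative check (not in the diff) of the fields carried by the new period dtypes defined above:

import numpy as np
from spikeinterface.core.node_pipeline import base_period_dtype, unit_period_dtype

print(np.dtype(base_period_dtype).names)
# ('start_sample_index', 'end_sample_index', 'segment_index')
print(np.dtype(unit_period_dtype).names)
# ('start_sample_index', 'end_sample_index', 'segment_index', 'unit_index')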
77 changes: 77 additions & 0 deletions src/spikeinterface/core/sorting_tools.py
@@ -228,6 +228,83 @@ def random_spikes_selection(
return random_spikes_indices


def select_sorting_periods_mask(sorting: BaseSorting, periods):
"""
Returns a boolean mask for the spikes in the sorting object, restricted to the given periods of dtype unit_period_dtype.

Parameters
----------
sorting : BaseSorting
The sorting object.
periods : numpy.array of unit_period_dtype
Periods (segment_index, start_sample_index, end_sample_index, unit_index)
on which to restrict the sorting.

Returns
-------
numpy.array
A boolean mask of the spikes in the sorting object, with True for spikes within the specified periods.
"""
spike_vector = sorting.to_spike_vector()
spike_vector_list = sorting.to_spike_vector(concatenated=False)
keep_mask = np.zeros(len(spike_vector), dtype=bool)
all_global_indices = spike_vector_to_indices(spike_vector_list, unit_ids=sorting.unit_ids, absolute_index=True)
Collaborator: why not using self.sorting_analyzer.sorting.get_spike_vector_to_indices(), and thus make use of a possible cache?

for segment_index in range(sorting.get_num_segments()):
global_indices_segment = all_global_indices[segment_index]
# filter periods by segment
periods_in_segment = periods[periods["segment_index"] == segment_index]
Collaborator: assuming not too many periods, these masks would be fine. Otherwise, we'll need to optimize.

for unit_index, unit_id in enumerate(sorting.unit_ids):
# filter by unit index
periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index]
global_indices = global_indices_segment[unit_id]
spiketrains = spike_vector[global_indices]["sample_index"]
if len(periods_for_unit) > 0:
for period in periods_for_unit:
mask = (spiketrains >= period["start_sample_index"]) & (spiketrains < period["end_sample_index"])
keep_mask[global_indices[mask]] = True
return keep_mask


def select_sorting_periods(sorting: BaseSorting, periods) -> BaseSorting:
"""
Returns a new sorting object, restricted to the given periods of dtype unit_period_dtype.

Parameters
----------
sorting : BaseSorting
The sorting object.
periods : numpy.array of unit_period_dtype
Periods (segment_index, start_sample_index, end_sample_index, unit_index)
on which to restrict the sorting.

Returns
-------
BaseSorting
A new sorting object containing only the spikes that fall between start_sample_index and end_sample_index
for the corresponding segment_index and unit_index.
"""
from spikeinterface.core.numpyextractors import NumpySorting
from spikeinterface.core.node_pipeline import unit_period_dtype

if periods is not None:
if not isinstance(periods, np.ndarray):
periods = np.array([periods], dtype=unit_period_dtype)
required = set(np.dtype(unit_period_dtype).names)
if not required.issubset(periods.dtype.names):
raise ValueError(f"Period must have the following fields: {required}")

spike_vector = sorting.to_spike_vector()
keep_mask = select_sorting_periods_mask(sorting, periods)
sliced_spike_vector = spike_vector[keep_mask]

new_sorting = NumpySorting(
sliced_spike_vector, sampling_frequency=sorting.sampling_frequency, unit_ids=sorting.unit_ids
)
sorting.copy_metadata(new_sorting)
return new_sorting
else:
return sorting


### MERGING ZONE ###
def apply_merges_to_sorting(
sorting: BaseSorting,
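For illustration (not part of the diff), a sketch of calling select_sorting_periods_mask directly: the returned mask is aligned with the concatenated spike vector, which is what lets BaseSpikeVectorExtension._get_data above slice per-spike data with it. The generate_sorting parameters are arbitrary:

import numpy as np
from spikeinterface.core import generate_sorting
from spikeinterface.core.node_pipeline import unit_period_dtype
from spikeinterface.core.sorting_tools import select_sorting_periods_mask

sorting = generate_sorting(durations=[5.0], sampling_frequency=30000.0, num_units=3, seed=42)

periods = np.zeros(1, dtype=unit_period_dtype)
periods["segment_index"] = 0
periods["start_sample_index"] = 0
periods["end_sample_index"] = 15_000
periods["unit_index"] = 0

keep_mask = select_sorting_periods_mask(sorting, periods)
spike_vector = sorting.to_spike_vector()
kept_spikes = spike_vector[keep_mask]  # only unit 0 spikes in the first 15_000 samples of segment 0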
69 changes: 63 additions & 6 deletions src/spikeinterface/core/tests/test_basesorting.py
@@ -3,9 +3,7 @@
but check only for BaseRecording general methods.
"""

import shutil
from pathlib import Path

import time
import numpy as np
import pytest
from numpy.testing import assert_raises
@@ -17,15 +15,15 @@
SharedMemorySorting,
NpzFolderSorting,
NumpyFolderSorting,
generate_ground_truth_recording,
generate_sorting,
create_sorting_npz,
generate_sorting,
load,
)
from spikeinterface.core.base import BaseExtractor
from spikeinterface.core.testing import check_sorted_arrays_equal, check_sortings_equal
from spikeinterface.core.generate import generate_sorting

from spikeinterface.core import generate_recording, generate_ground_truth_recording
from spikeinterface.core.node_pipeline import unit_period_dtype


def test_BaseSorting(create_cache_folder):
@@ -226,7 +224,66 @@ def test_time_slice():
)


def test_select_periods():
sampling_frequency = 10_000.0
duration = 1_000
num_samples = int(sampling_frequency * duration)
num_units = 1000
sorting = generate_sorting(
durations=[duration, duration], sampling_frequency=sampling_frequency, num_units=num_units
)

rng = np.random.default_rng()

# number of random periods
n_periods = 10_000
Member: duration, num_units and n_periods are all quite large for a test. Is it slow??

Member: I agree we should put a small number.

Member Author: takes about 1 second

# generate random periods
segment_indices = rng.integers(0, sorting.get_num_segments(), n_periods)
start_samples = rng.integers(0, num_samples, n_periods)
durations = rng.integers(100, 100_000, n_periods)
end_samples = start_samples + durations
valid_periods = end_samples < num_samples
segment_indices = segment_indices[valid_periods]
start_samples = start_samples[valid_periods]
end_samples = end_samples[valid_periods]
unit_index = rng.integers(0, num_units - 1, len(segment_indices))

periods = np.zeros(len(segment_indices), dtype=unit_period_dtype)
periods["segment_index"] = segment_indices
periods["start_sample_index"] = start_samples
periods["end_sample_index"] = end_samples
periods["unit_index"] = unit_index
periods = np.sort(periods, order=["segment_index", "start_sample_index"])

t_start = time.perf_counter()
sliced_sorting = sorting.select_periods(periods=periods)
t_stop = time.perf_counter()
elapsed = t_stop - t_start
print(f"select_periods took {elapsed:.2f} seconds for {len(periods)} periods")

# Check that all spikes in the sliced sorting are within the periods
for segment_index in range(sorting.get_num_segments()):
periods_in_segment = periods[periods["segment_index"] == segment_index]
Collaborator: do we assume not too many periods, just to be sure that all these masks won't take too long?

for unit_index, unit_id in enumerate(sorting.unit_ids):
spiketrain = sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id)

periods_for_unit = periods_in_segment[periods_in_segment["unit_index"] == unit_index]
spiketrain_in_periods = []
for period in periods_for_unit:
start_sample = period["start_sample_index"]
end_sample = period["end_sample_index"]
spiketrain_in_periods.append(spiketrain[(spiketrain >= start_sample) & (spiketrain < end_sample)])
if len(spiketrain_in_periods) == 0:
spiketrain_in_periods = np.array([], dtype=spiketrain.dtype)
else:
spiketrain_in_periods = np.unique(np.concatenate(spiketrain_in_periods))

spiketrain_sliced = sliced_sorting.get_unit_spike_train(segment_index=segment_index, unit_id=unit_id)
assert len(spiketrain_in_periods) == len(spiketrain_sliced)


if __name__ == "__main__":
test_BaseSorting()
test_npy_sorting()
test_empty_sorting()
test_select_periods()