Commit 66fc48a

Add BaseSpikeVectorExtension (#4189)
Co-authored-by: Garcia Samuel <[email protected]>
1 parent 123e862 commit 66fc48a

11 files changed: +170, -308 lines

src/spikeinterface/core/analyzer_extension_core.py

Lines changed: 147 additions & 6 deletions
@@ -11,12 +11,14 @@

 import warnings
 import numpy as np
+from collections import namedtuple

-from .sortinganalyzer import AnalyzerExtension, register_result_extension
+from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension
 from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
 from .recording_tools import get_noise_levels
 from .template import Templates
 from .sorting_tools import random_spikes_selection
+from .job_tools import fix_job_kwargs, split_job_kwargs


 class ComputeRandomSpikes(AnalyzerExtension):

@@ -752,8 +754,6 @@ class ComputeNoiseLevels(AnalyzerExtension):

     Parameters
     ----------
-    sorting_analyzer : SortingAnalyzer
-        A SortingAnalyzer object
     **kwargs : dict
         Additional parameters for the `spikeinterface.get_noise_levels()` function

@@ -770,9 +770,6 @@ class ComputeNoiseLevels(AnalyzerExtension):
     need_job_kwargs = True
     need_backward_compatibility_on_load = True

-    def __init__(self, sorting_analyzer):
-        AnalyzerExtension.__init__(self, sorting_analyzer)
-
     def _set_params(self, **noise_level_params):
         params = noise_level_params.copy()
         return params

@@ -814,3 +811,147 @@ def _handle_backward_compatibility_on_load(self):

 register_result_extension(ComputeNoiseLevels)
 compute_noise_levels = ComputeNoiseLevels.function_factory()
+
+
+class BaseSpikeVectorExtension(AnalyzerExtension):
+    """
+    Base class for spike-vector-based extensions, where the data is a numpy array with the same
+    length as the spike vector.
+    """
+
+    extension_name = None  # to be defined in subclass
+    need_recording = True
+    use_nodepipeline = True
+    need_job_kwargs = True
+    need_backward_compatibility_on_load = False
+    nodepipeline_variables = []  # to be defined in subclass
+
+    def _set_params(self, **kwargs):
+        params = kwargs.copy()
+        return params
+
+    def _run(self, verbose=False, **job_kwargs):
+        from spikeinterface.core.node_pipeline import run_node_pipeline
+
+        # TODO: should we save directly to npy in binary_folder format / or to zarr?
+        # if self.sorting_analyzer.format == "binary_folder":
+        #     gather_mode = "npy"
+        #     extension_folder = self.sorting_analyzer.folder / "extensions" / self.extension_name
+        #     gather_kwargs = {"folder": extension_folder}
+        gather_mode = "memory"
+        gather_kwargs = {}
+
+        job_kwargs = fix_job_kwargs(job_kwargs)
+        nodes = self.get_pipeline_nodes()
+        data = run_node_pipeline(
+            self.sorting_analyzer.recording,
+            nodes,
+            job_kwargs=job_kwargs,
+            job_name=self.extension_name,
+            gather_mode=gather_mode,
+            gather_kwargs=gather_kwargs,
+            verbose=False,
+        )
+        if isinstance(data, tuple):
+            # this logic enables extensions to optionally compute additional data based on params
+            assert len(data) <= len(self.nodepipeline_variables), "Pipeline produced more outputs than expected"
+        else:
+            data = (data,)
+        if len(self.nodepipeline_variables) > len(data):
+            data_names = self.nodepipeline_variables[: len(data)]
+        else:
+            data_names = self.nodepipeline_variables
+        for d, name in zip(data, data_names):
+            self.data[name] = d
+
+    def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, copy=True):
+        """
+        Return extension data. If the extension computes more than one of its `nodepipeline_variables`,
+        `return_data_name` specifies which one to return.
+
+        Parameters
+        ----------
+        outputs : "numpy" | "by_unit", default: "numpy"
+            How to return the data
+        concatenated : bool, default: False
+            Whether to concatenate the data across segments
+        return_data_name : str | None, default: None
+            The name of the data to return. If None and multiple `nodepipeline_variables` are computed,
+            the first one is returned.
+        copy : bool, default: True
+            Whether to return a copy of the data (only for outputs="numpy")
+
+        Returns
+        -------
+        numpy.ndarray | dict
+            The extension data, as a single array (outputs="numpy") or as a per-unit
+            dictionary (outputs="by_unit")
+        """
+        from spikeinterface.core.sorting_tools import spike_vector_to_indices
+
+        if len(self.nodepipeline_variables) == 1:
+            return_data_name = self.nodepipeline_variables[0]
+        else:
+            if return_data_name is None:
+                return_data_name = self.nodepipeline_variables[0]
+            else:
+                assert (
+                    return_data_name in self.nodepipeline_variables
+                ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"
+
+        all_data = self.data[return_data_name]
+        if outputs == "numpy":
+            if copy:
+                return all_data.copy()  # return a copy to avoid modification
+            else:
+                return all_data
+        elif outputs == "by_unit":
+            unit_ids = self.sorting_analyzer.unit_ids
+            spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
+            spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
+            data_by_units = {}
+            for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
+                data_by_units[segment_index] = {}
+                for unit_id in unit_ids:
+                    inds = spike_indices[segment_index][unit_id]
+                    data_by_units[segment_index][unit_id] = all_data[inds]
+
+            if concatenated:
+                data_by_units_concatenated = {
+                    unit_id: np.concatenate([data_in_segment[unit_id] for data_in_segment in data_by_units.values()])
+                    for unit_id in unit_ids
+                }
+                return data_by_units_concatenated
+
+            return data_by_units
+        else:
+            raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`")
+
+    def _select_extension_data(self, unit_ids):
+        keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))
+
+        spikes = self.sorting_analyzer.sorting.to_spike_vector()
+        keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices)
+
+        new_data = dict()
+        for data_name in self.nodepipeline_variables:
+            if self.data.get(data_name) is not None:
+                new_data[data_name] = self.data[data_name][keep_spike_mask]
+
+        return new_data
+
+    def _merge_extension_data(
+        self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
+    ):
+        new_data = dict()
+        for data_name in self.nodepipeline_variables:
+            if self.data.get(data_name) is not None:
+                if keep_mask is None:
+                    new_data[data_name] = self.data[data_name].copy()
+                else:
+                    new_data[data_name] = self.data[data_name][keep_mask]
+
+        return new_data
+
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        # splitting units does not change the spike vector, so the data can be copied as-is
+        return self.data.copy()
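
For orientation, here is a minimal sketch (hypothetical, not part of this commit) of how a subclass is expected to use the new base class: declare extension_name, depend_on and nodepipeline_variables, and implement _get_pipeline_nodes(); _run(), _get_data() and the select/merge/split handlers are all inherited from BaseSpikeVectorExtension. The class and variable names below are invented for illustration.

from spikeinterface.core.sortinganalyzer import register_result_extension
from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension


class ComputeMySpikeFeature(BaseSpikeVectorExtension):
    # hypothetical extension computing one value per spike
    extension_name = "my_spike_feature"
    depend_on = ["templates"]
    nodepipeline_variables = ["my_spike_feature"]

    def _set_params(self, ms_before=0.5, ms_after=1.5):
        # the base class simply stores the kwargs as params
        return super()._set_params(ms_before=ms_before, ms_after=ms_after)

    def _get_pipeline_nodes(self):
        # build the node pipeline here, e.g. a SpikeRetriever followed by a
        # PipelineNode emitting one row per spike (omitted in this sketch)
        raise NotImplementedError


register_result_extension(ComputeMySpikeFeature)
compute_my_spike_feature = ComputeMySpikeFeature.function_factory()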

src/spikeinterface/core/node_pipeline.py

Lines changed: 4 additions & 5 deletions
@@ -317,6 +317,7 @@ def __init__(
         self.ms_after = ms_after
         self.nbefore = int(ms_before * recording.get_sampling_frequency() / 1000.0)
         self.nafter = int(ms_after * recording.get_sampling_frequency() / 1000.0)
+        self.neighbours_mask = None


 class ExtractDenseWaveforms(WaveformsNode):

@@ -356,8 +357,6 @@ def __init__(
             ms_after=ms_after,
             return_output=return_output,
         )
-        # this is a bad hack to differentiate in the child if the parents is dense or not.
-        self.neighbours_mask = None

     def get_trace_margin(self):
         return max(self.nbefore, self.nafter)

@@ -573,7 +572,7 @@ def run_node_pipeline(
     gather_mode : "memory" | "npy"
         How to gather the output of the nodes.
     gather_kwargs : dict
-        OPtions to control the "gather engine". See GatherToMemory or GatherToNpy.
+        Options to control the "gather engine". See GatherToMemory or GatherToNpy.
     squeeze_output : bool, default True
         If only one output node then squeeze the tuple
     folder : str | Path | None

@@ -784,7 +783,7 @@ def finalize_buffers(self, squeeze_output=False):

 class GatherToNpy:
     """
-    Gather output of nodes into npy file and then open then as memmap.
+    Gather output of nodes into npy file and then open them as memmap.

     The trick is:

@@ -891,6 +890,6 @@ def finalize_buffers(self, squeeze_output=False):
         return np.load(filename, mmap_mode="r")


-class GatherToHdf5:
+class GatherToZarr:
     pass
     # Fot me (sam) this is not necessary unless someone realy really want to use
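
The gather mode chosen in run_node_pipeline decides where per-spike outputs land: GatherToMemory keeps arrays in RAM, while GatherToNpy streams them to .npy files that are reopened as memmaps. A hedged sketch of the call, mirroring the commented-out "npy" branch in BaseSpikeVectorExtension._run above; `recording` and `nodes` are assumed to exist, and the gather_kwargs follow that commented code rather than a documented API.

from spikeinterface.core.node_pipeline import run_node_pipeline

# `recording` is a Recording and `nodes` a node list, e.g. from an
# extension's get_pipeline_nodes(); both are assumed here
data = run_node_pipeline(
    recording,
    nodes,
    job_kwargs={"n_jobs": 4, "chunk_duration": "1s"},
    job_name="my_pipeline",
    gather_mode="npy",  # stream to .npy files, reopened as memmaps (GatherToNpy)
    gather_kwargs={"folder": "pipeline_output"},
)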

src/spikeinterface/postprocessing/amplitude_scalings.py

Lines changed: 8 additions & 87 deletions
@@ -3,18 +3,14 @@

 import numpy as np

 from spikeinterface.core import ChannelSparsity
-from spikeinterface.core.job_tools import ChunkRecordingExecutor, _shared_job_kwargs_doc, ensure_n_jobs, fix_job_kwargs
+from spikeinterface.core.template_tools import get_template_extremum_channel, get_dense_templates_array, _get_nbefore
+from spikeinterface.core.sortinganalyzer import register_result_extension
+from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension

-from spikeinterface.core.template_tools import get_template_extremum_channel
+from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, find_parent_of_type

-from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension
-
-from spikeinterface.core.node_pipeline import SpikeRetriever, PipelineNode, run_node_pipeline, find_parent_of_type
-
-
-class ComputeAmplitudeScalings(AnalyzerExtension):
+class ComputeAmplitudeScalings(BaseSpikeVectorExtension):
     """
     Computes the amplitude scalings from a SortingAnalyzer.

@@ -55,31 +51,11 @@ class ComputeAmplitudeScalings(AnalyzerExtension):
         multi-linear regression model (with `sklearn.LinearRegression`). If False, each spike is fitted independently.
     delta_collision_ms: float, default: 2
         The maximum time difference in ms before and after a spike to gather colliding spikes.
-    load_if_exists : bool, default: False
-        Whether to load precomputed spike amplitudes, if they already exist.
-    outputs: "concatenated" | "by_unit", default: "concatenated"
-        How the output should be returned
-    {}
-
-    Returns
-    -------
-    amplitude_scalings: np.array or list of dict
-        The amplitude scalings.
-        - If "concatenated" all amplitudes for all spikes and all units are concatenated
-        - If "by_unit", amplitudes are returned as a list (for segments) of dictionaries (for units)
     """

     extension_name = "amplitude_scalings"
     depend_on = ["templates"]
-    need_recording = True
-    use_nodepipeline = True
     nodepipeline_variables = ["amplitude_scalings", "collision_mask"]
-    need_job_kwargs = True
-
-    def __init__(self, sorting_analyzer):
-        AnalyzerExtension.__init__(self, sorting_analyzer)
-
-        self.collisions = None

     def _set_params(

@@ -90,46 +66,14 @@ def _set_params(
         handle_collisions=True,
         delta_collision_ms=2,
     ):
-        params = dict(
+        return super()._set_params(
             sparsity=sparsity,
             max_dense_channels=max_dense_channels,
             ms_before=ms_before,
             ms_after=ms_after,
             handle_collisions=handle_collisions,
             delta_collision_ms=delta_collision_ms,
         )
-        return params
-
-    def _select_extension_data(self, unit_ids):
-        keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))
-
-        spikes = self.sorting_analyzer.sorting.to_spike_vector()
-        keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices)
-
-        new_data = dict()
-        new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_spike_mask]
-        if self.params["handle_collisions"]:
-            new_data["collision_mask"] = self.data["collision_mask"][keep_spike_mask]
-        return new_data
-
-    def _merge_extension_data(
-        self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
-    ):
-        new_data = dict()
-
-        if keep_mask is None:
-            new_data["amplitude_scalings"] = self.data["amplitude_scalings"].copy()
-            if self.params["handle_collisions"]:
-                new_data["collision_mask"] = self.data["collision_mask"].copy()
-        else:
-            new_data["amplitude_scalings"] = self.data["amplitude_scalings"][keep_mask]
-            if self.params["handle_collisions"]:
-                new_data["collision_mask"] = self.data["collision_mask"][keep_mask]
-
-        return new_data
-
-    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
-        return self.data.copy()

     def _get_pipeline_nodes(self):

@@ -141,6 +85,7 @@ def _get_pipeline_nodes(self):
         all_templates = get_dense_templates_array(self.sorting_analyzer, return_in_uV=return_in_uV)
         nbefore = _get_nbefore(self.sorting_analyzer)
         nafter = all_templates.shape[1] - nbefore
+        templates_ext = self.sorting_analyzer.get_extension("templates")

         # if ms_before / ms_after are set in params then the original templates are shorten
         if self.params["ms_before"] is not None:

@@ -155,7 +100,7 @@ def _get_pipeline_nodes(self):
             cut_out_after = int(self.params["ms_after"] * self.sorting_analyzer.sampling_frequency / 1000.0)
             assert (
                 cut_out_after <= nafter
-            ), f"`ms_after` must be smaller than `ms_after` used in WaveformExractor: {we._params['ms_after']}"
+            ), f"`ms_after` must be smaller than `ms_after` used in templates: {templates_ext.params['ms_after']}"
         else:
             cut_out_after = nafter

@@ -210,30 +155,6 @@ def _get_pipeline_nodes(self):
         nodes = [spike_retriever_node, amplitude_scalings_node]
         return nodes

-    def _run(self, verbose=False, **job_kwargs):
-        job_kwargs = fix_job_kwargs(job_kwargs)
-        nodes = self.get_pipeline_nodes()
-        amp_scalings, collision_mask = run_node_pipeline(
-            self.sorting_analyzer.recording,
-            nodes,
-            job_kwargs=job_kwargs,
-            job_name="amplitude_scalings",
-            gather_mode="memory",
-            verbose=verbose,
-        )
-        self.data["amplitude_scalings"] = amp_scalings
-        if self.params["handle_collisions"]:
-            self.data["collision_mask"] = collision_mask
-        # TODO: make collisions "global"
-        # for collision in collisions:
-        #     collisions_dict.update(collision)
-        # self.collisions = collisions_dict
-        # # Note: collisions are note in _extension_data because they are not pickable. We only store the indices
-        # self._extension_data["collisions"] = np.array(list(collisions_dict.keys()))
-
-    def _get_data(self):
-        return self.data[f"amplitude_scalings"]
-

 register_result_extension(ComputeAmplitudeScalings)
 compute_amplitude_scalings = ComputeAmplitudeScalings.function_factory()