 
 import warnings
 import numpy as np
+from collections import namedtuple
 
-from .sortinganalyzer import AnalyzerExtension, register_result_extension
+from .sortinganalyzer import SortingAnalyzer, AnalyzerExtension, register_result_extension
 from .waveform_tools import extract_waveforms_to_single_buffer, estimate_templates_with_accumulator
 from .recording_tools import get_noise_levels
 from .template import Templates
 from .sorting_tools import random_spikes_selection
+from .job_tools import fix_job_kwargs, split_job_kwargs
 
 
 class ComputeRandomSpikes(AnalyzerExtension):
@@ -752,8 +754,6 @@ class ComputeNoiseLevels(AnalyzerExtension):
 
     Parameters
     ----------
-    sorting_analyzer : SortingAnalyzer
-        A SortingAnalyzer object
     **kwargs : dict
         Additional parameters for the `spikeinterface.get_noise_levels()` function
 
@@ -770,9 +770,6 @@ class ComputeNoiseLevels(AnalyzerExtension):
     need_job_kwargs = True
     need_backward_compatibility_on_load = True
 
-    def __init__(self, sorting_analyzer):
-        AnalyzerExtension.__init__(self, sorting_analyzer)
-
     def _set_params(self, **noise_level_params):
         params = noise_level_params.copy()
         return params
@@ -814,3 +811,147 @@ def _handle_backward_compatibility_on_load(self):
 
 register_result_extension(ComputeNoiseLevels)
 compute_noise_levels = ComputeNoiseLevels.function_factory()
+
+
+class BaseSpikeVectorExtension(AnalyzerExtension):
+    """
+    Base class for spike-vector-based extensions, where the data is a numpy array with the same
+    length as the spike vector.
+    """
+
+    extension_name = None  # to be defined in subclass
+    need_recording = True
+    use_nodepipeline = True
+    need_job_kwargs = True
+    need_backward_compatibility_on_load = False
+    nodepipeline_variables = []  # to be defined in subclass
+
+    def _set_params(self, **kwargs):
+        params = kwargs.copy()
+        return params
+
+    def _run(self, verbose=False, **job_kwargs):
+        from spikeinterface.core.node_pipeline import run_node_pipeline
+
+        # TODO: should we save directly to npy in binary_folder format / or to zarr?
+        # if self.sorting_analyzer.format == "binary_folder":
+        #     gather_mode = "npy"
+        #     extension_folder = self.sorting_analyzer.folder / "extensions" / self.extension_name
+        #     gather_kwargs = {"folder": extension_folder}
+        gather_mode = "memory"
+        gather_kwargs = {}
+
+        job_kwargs = fix_job_kwargs(job_kwargs)
+        nodes = self.get_pipeline_nodes()
+        data = run_node_pipeline(
+            self.sorting_analyzer.recording,
+            nodes,
+            job_kwargs=job_kwargs,
+            job_name=self.extension_name,
+            gather_mode=gather_mode,
+            gather_kwargs=gather_kwargs,
+            verbose=False,
+        )
+        if isinstance(data, tuple):
+            # this logic enables extensions to optionally compute additional data based on params
+            assert len(data) <= len(self.nodepipeline_variables), "Pipeline produced more outputs than expected"
+        else:
+            data = (data,)
+        if len(self.nodepipeline_variables) > len(data):
+            data_names = self.nodepipeline_variables[: len(data)]
+        else:
+            data_names = self.nodepipeline_variables
+        for d, name in zip(data, data_names):
+            self.data[name] = d
+
+    def _get_data(self, outputs="numpy", concatenated=False, return_data_name=None, copy=True):
+        """
+        Return extension data. If the extension computes more than one of the `nodepipeline_variables`,
+        `return_data_name` is used to specify which one to return.
+
+        Parameters
+        ----------
+        outputs : "numpy" | "by_unit", default: "numpy"
+            Whether to return the data as a single numpy array or as a dictionary organized
+            by segment and unit.
+        concatenated : bool, default: False
+            Whether to concatenate the data across segments (used only when outputs="by_unit").
+        return_data_name : str | None, default: None
+            The name of the data to return. If None and multiple `nodepipeline_variables` are computed,
+            the first one is returned.
+        copy : bool, default: True
+            Whether to return a copy of the data (only for outputs="numpy")
+
+        Returns
+        -------
+        data : numpy.ndarray | dict
+            The extension data, either as a flat numpy array or as a dictionary indexed by segment and unit.
+        """
+        from spikeinterface.core.sorting_tools import spike_vector_to_indices
+
+        if len(self.nodepipeline_variables) == 1:
+            return_data_name = self.nodepipeline_variables[0]
+        else:
+            if return_data_name is None:
+                return_data_name = self.nodepipeline_variables[0]
+            else:
+                assert (
+                    return_data_name in self.nodepipeline_variables
+                ), f"return_data_name {return_data_name} not in nodepipeline_variables {self.nodepipeline_variables}"
+
+        all_data = self.data[return_data_name]
+        if outputs == "numpy":
+            if copy:
+                return all_data.copy()  # return a copy to avoid modification
+            else:
+                return all_data
+        elif outputs == "by_unit":
+            unit_ids = self.sorting_analyzer.unit_ids
+            spike_vector = self.sorting_analyzer.sorting.to_spike_vector(concatenated=False)
+            spike_indices = spike_vector_to_indices(spike_vector, unit_ids, absolute_index=True)
+            data_by_units = {}
+            for segment_index in range(self.sorting_analyzer.sorting.get_num_segments()):
+                data_by_units[segment_index] = {}
+                for unit_id in unit_ids:
+                    inds = spike_indices[segment_index][unit_id]
+                    data_by_units[segment_index][unit_id] = all_data[inds]
+
+            if concatenated:
+                data_by_units_concatenated = {
+                    unit_id: np.concatenate([data_in_segment[unit_id] for data_in_segment in data_by_units.values()])
+                    for unit_id in unit_ids
+                }
+                return data_by_units_concatenated
+
+            return data_by_units
+        else:
+            raise ValueError(f"Wrong .get_data(outputs={outputs}); possibilities are `numpy` or `by_unit`")
+
+    def _select_extension_data(self, unit_ids):
+        keep_unit_indices = np.flatnonzero(np.isin(self.sorting_analyzer.unit_ids, unit_ids))
+
+        spikes = self.sorting_analyzer.sorting.to_spike_vector()
+        keep_spike_mask = np.isin(spikes["unit_index"], keep_unit_indices)
+
+        new_data = dict()
+        for data_name in self.nodepipeline_variables:
+            if self.data.get(data_name) is not None:
+                new_data[data_name] = self.data[data_name][keep_spike_mask]
+
+        return new_data
+
+    def _merge_extension_data(
+        self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
+    ):
+        new_data = dict()
+        for data_name in self.nodepipeline_variables:
+            if self.data.get(data_name) is not None:
+                if keep_mask is None:
+                    new_data[data_name] = self.data[data_name].copy()
+                else:
+                    new_data[data_name] = self.data[data_name][keep_mask]
+
+        return new_data
+
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        # splitting units only reassigns spikes to new unit ids; the per-spike data itself is unchanged
+        return self.data.copy()
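
For illustration, a minimal sketch of how a subclass could build on BaseSpikeVectorExtension. The ComputeMySpikeMetric name, its "my_metric" variable, and the node construction are hypothetical and not part of this commit; the sketch assumes the usual AnalyzerExtension pattern where subclasses provide their node-pipeline nodes via _get_pipeline_nodes().

# Hypothetical sketch (not part of this commit): a spike-vector extension only needs to
# declare its name and output variables and supply pipeline nodes that emit one value per spike.
class ComputeMySpikeMetric(BaseSpikeVectorExtension):
    extension_name = "my_spike_metric"
    nodepipeline_variables = ["my_metric"]

    def _get_pipeline_nodes(self):
        # build and return the node-pipeline nodes computing one value per spike
        ...

# usage sketch:
#   ext = sorting_analyzer.compute("my_spike_metric")
#   values = ext.get_data()                                        # flat array, same length as the spike vector
#   by_unit = ext.get_data(outputs="by_unit", concatenated=True)   # dict: unit_id -> array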