diff --git a/src/spikeinterface/benchmark/benchmark_matching.py b/src/spikeinterface/benchmark/benchmark_matching.py index 02e2146b5a..9a02a8e3cd 100644 --- a/src/spikeinterface/benchmark/benchmark_matching.py +++ b/src/spikeinterface/benchmark/benchmark_matching.py @@ -12,7 +12,7 @@ import numpy as np from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount -from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.base import minimum_spike_dtype class MatchingBenchmark(Benchmark): diff --git a/src/spikeinterface/benchmark/benchmark_peak_detection.py b/src/spikeinterface/benchmark/benchmark_peak_detection.py index ef25e7dd29..c687a81312 100644 --- a/src/spikeinterface/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/benchmark/benchmark_peak_detection.py @@ -13,7 +13,7 @@ import numpy as np from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs from .benchmark_base import Benchmark, BenchmarkStudy -from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.core.sortinganalyzer import create_sorting_analyzer from .benchmark_plot_tools import fit_sigmoid, sigmoid diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 6fe49b8606..3505853835 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -27,6 +27,31 @@ from .job_tools import _shared_job_kwargs_doc +# base dtypes used throughout spikeinterface +base_peak_dtype = [ + ("sample_index", "int64"), + ("channel_index", "int64"), + ("amplitude", "float64"), + ("segment_index", "int64"), +] + +spike_peak_dtype = base_peak_dtype + [ + ("unit_index", "int64"), +] + +minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] + +base_period_dtype = [ + ("start_sample_index", "int64"), + ("end_sample_index", "int64"), + ("segment_index", "int64"), +] + +unit_period_dtype = base_period_dtype + [ + ("unit_index", "int64"), +] + + class BaseExtractor: """ Base class for Recording/Sorting diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 98159fb646..a28a2dbb66 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -5,13 +5,10 @@ import numpy as np -from .base import BaseExtractor, BaseSegment +from .base import BaseExtractor, BaseSegment, minimum_spike_dtype from .waveform_tools import has_exceeding_spikes -minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")] - - class BaseSorting(BaseExtractor): """ Abstract class representing several segment several units and relative spiketrains. diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index ce613113b7..eddbe318a6 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -5,9 +5,9 @@ from typing import Literal, Optional from math import ceil +from .base import minimum_spike_dtype from .basesorting import SpikeVectorSortingSegment from .numpyextractors import NumpySorting -from .basesorting import minimum_spike_dtype from probeinterface import Probe, generate_linear_probe, generate_multi_columns_probe diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index 71654a67b4..1609f11d17 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -10,24 +10,12 @@ import numpy as np +from spikeinterface.core.base import base_peak_dtype, spike_peak_dtype from spikeinterface.core import BaseRecording, get_chunk_with_margin from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs, _shared_job_kwargs_doc from spikeinterface.core import get_channel_distances -base_peak_dtype = [ - ("sample_index", "int64"), - ("channel_index", "int64"), - ("amplitude", "float64"), - ("segment_index", "int64"), -] - - -spike_peak_dtype = base_peak_dtype + [ - ("unit_index", "int64"), -] - - class PipelineNode: # If False (general case) then compute(traces_chunk, *node_input_args) diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 019759797b..f14b1a879a 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -12,7 +12,7 @@ BaseSnippets, BaseSnippetsSegment, ) -from .basesorting import minimum_spike_dtype +from .base import minimum_spike_dtype from .core_tools import make_shared_array from .recording_tools import write_memory_recording from multiprocessing.shared_memory import SharedMemory diff --git a/src/spikeinterface/core/tests/test_node_pipeline.py b/src/spikeinterface/core/tests/test_node_pipeline.py index 028eaecf12..e8b32dc99d 100644 --- a/src/spikeinterface/core/tests/test_node_pipeline.py +++ b/src/spikeinterface/core/tests/test_node_pipeline.py @@ -4,8 +4,10 @@ import shutil from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording +from spikeinterface.core.base import spike_peak_dtype from spikeinterface.core.job_tools import divide_recording_into_chunks + # from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.core.node_pipeline import ( run_node_pipeline, @@ -14,7 +16,6 @@ PipelineNode, ExtractDenseWaveforms, sorting_to_peaks, - spike_peak_dtype, ) diff --git a/src/spikeinterface/core/tests/test_numpy_extractors.py b/src/spikeinterface/core/tests/test_numpy_extractors.py index 50426f6222..6eb9918b66 100644 --- a/src/spikeinterface/core/tests/test_numpy_extractors.py +++ b/src/spikeinterface/core/tests/test_numpy_extractors.py @@ -13,7 +13,7 @@ generate_recording, ) -from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.core.testing import check_sortings_equal diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 4096740637..c0c9f9b5d7 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -7,8 +7,9 @@ from probeinterface import ProbeGroup +from .base import minimum_spike_dtype from .baserecording import BaseRecording, BaseRecordingSegment -from .basesorting import BaseSorting, SpikeVectorSortingSegment, minimum_spike_dtype +from .basesorting import BaseSorting, SpikeVectorSortingSegment from .core_tools import define_function_from_class, check_json from .job_tools import split_job_kwargs from .core_tools import is_path_remote diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index c0dd6c6033..fb764dac78 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -53,7 +53,7 @@ ) -from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.base import minimum_spike_dtype job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s") diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py index b1ae00b64c..f8323b2e78 100644 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ b/src/spikeinterface/preprocessing/silence_artifacts.py @@ -2,15 +2,15 @@ import numpy as np +from spikeinterface.core.base import base_peak_dtype from spikeinterface.core.core_tools import define_function_handling_dict_from_class +from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs +from spikeinterface.core.recording_tools import get_noise_levels +from spikeinterface.core.node_pipeline import PeakDetector from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording from spikeinterface.preprocessing.rectify import RectifyRecording from spikeinterface.preprocessing.common_reference import CommonReferenceRecording from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording -from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype -import numpy as np class DetectThresholdCrossing(PeakDetector): diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index e09036fd8c..df1e98b57b 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -15,7 +15,7 @@ from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten -from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index a43ade9a85..5351a4d62d 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -15,7 +15,7 @@ get_shuffled_recording_slices, _set_optimal_chunk_size, ) -from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.core import compute_sparsity diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index a4306dd431..373305c336 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -18,7 +18,7 @@ from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten -from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing diff --git a/src/spikeinterface/sortingcomponents/clustering/random_projections.py b/src/spikeinterface/sortingcomponents/clustering/random_projections.py index ac059e22ea..3586977440 100644 --- a/src/spikeinterface/sortingcomponents/clustering/random_projections.py +++ b/src/spikeinterface/sortingcomponents/clustering/random_projections.py @@ -13,7 +13,7 @@ else: HAVE_HDBSCAN = False -from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.core.waveform_tools import estimate_templates from spikeinterface.sortingcomponents.clustering.merging_tools import merge_peak_labels_from_templates from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser diff --git a/src/spikeinterface/sortingcomponents/clustering/tools.py b/src/spikeinterface/sortingcomponents/clustering/tools.py index 406aee1bd4..1f0fb652ff 100644 --- a/src/spikeinterface/sortingcomponents/clustering/tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/tools.py @@ -5,7 +5,7 @@ import numpy as np from spikeinterface.core import Templates, estimate_templates, fix_job_kwargs -from spikeinterface.core.basesorting import minimum_spike_dtype +from spikeinterface.core.base import minimum_spike_dtype # TODO find a way to attach a a sparse_mask to a given features (waveforms, pca, tsvd ....) diff --git a/src/spikeinterface/sortingcomponents/peak_detection/iterative.py b/src/spikeinterface/sortingcomponents/peak_detection/iterative.py index ecce5951c8..818c2a490f 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/iterative.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/iterative.py @@ -5,13 +5,12 @@ import numpy as np - +from spikeinterface.core.base import base_peak_dtype from spikeinterface.core.baserecording import BaseRecording from spikeinterface.core.node_pipeline import ( PeakDetector, WaveformsNode, ExtractSparseWaveforms, - base_peak_dtype, ) expanded_base_peak_dtype = np.dtype(base_peak_dtype + [("iteration", "int8")]) diff --git a/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py b/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py index 46e0709d9e..c6acf22048 100644 --- a/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py +++ b/src/spikeinterface/sortingcomponents/peak_detection/matched_filtering.py @@ -2,18 +2,14 @@ from __future__ import annotations +import importlib.util import numpy as np - -from spikeinterface.core.node_pipeline import ( - PeakDetector, - base_peak_dtype, -) - +from spikeinterface.core.base import base_peak_dtype +from spikeinterface.core.node_pipeline import PeakDetector from spikeinterface.core.recording_tools import get_channel_distances, get_random_data_chunks from spikeinterface.postprocessing.localization_tools import get_convolution_weights -import importlib.util numba_spec = importlib.util.find_spec("numba") if numba_spec is not None: