Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spikeinterface/benchmark/benchmark_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import numpy as np
from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.base import minimum_spike_dtype


class MatchingBenchmark(Benchmark):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/benchmark/benchmark_peak_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
from .benchmark_base import Benchmark, BenchmarkStudy
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.base import minimum_spike_dtype
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from .benchmark_plot_tools import fit_sigmoid, sigmoid

Expand Down
25 changes: 25 additions & 0 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,31 @@
from .job_tools import _shared_job_kwargs_doc


# base dtypes used throughout spikeinterface
base_peak_dtype = [
("sample_index", "int64"),
("channel_index", "int64"),
("amplitude", "float64"),
("segment_index", "int64"),
]

spike_peak_dtype = base_peak_dtype + [
("unit_index", "int64"),
]

minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]

base_period_dtype = [
("start_sample_index", "int64"),
("end_sample_index", "int64"),
("segment_index", "int64"),
]

unit_period_dtype = base_period_dtype + [
("unit_index", "int64"),
]


class BaseExtractor:
"""
Base class for Recording/Sorting
Expand Down
5 changes: 1 addition & 4 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@

import numpy as np

from .base import BaseExtractor, BaseSegment
from .base import BaseExtractor, BaseSegment, minimum_spike_dtype
from .waveform_tools import has_exceeding_spikes


minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]


class BaseSorting(BaseExtractor):
"""
Abstract class representing several segment several units and relative spiketrains.
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from typing import Literal, Optional
from math import ceil

from .base import minimum_spike_dtype
from .basesorting import SpikeVectorSortingSegment
from .numpyextractors import NumpySorting
from .basesorting import minimum_spike_dtype

from probeinterface import Probe, generate_linear_probe, generate_multi_columns_probe

Expand Down
14 changes: 1 addition & 13 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,12 @@

import numpy as np

from spikeinterface.core.base import base_peak_dtype, spike_peak_dtype
from spikeinterface.core import BaseRecording, get_chunk_with_margin
from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs, _shared_job_kwargs_doc
from spikeinterface.core import get_channel_distances


base_peak_dtype = [
("sample_index", "int64"),
("channel_index", "int64"),
("amplitude", "float64"),
("segment_index", "int64"),
]


spike_peak_dtype = base_peak_dtype + [
("unit_index", "int64"),
]


class PipelineNode:

# If False (general case) then compute(traces_chunk, *node_input_args)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
BaseSnippets,
BaseSnippetsSegment,
)
from .basesorting import minimum_spike_dtype
from .base import minimum_spike_dtype
from .core_tools import make_shared_array
from .recording_tools import write_memory_recording
from multiprocessing.shared_memory import SharedMemory
Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/core/tests/test_node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import shutil

from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording
from spikeinterface.core.base import spike_peak_dtype
from spikeinterface.core.job_tools import divide_recording_into_chunks


# from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.node_pipeline import (
run_node_pipeline,
Expand All @@ -14,7 +16,6 @@
PipelineNode,
ExtractDenseWaveforms,
sorting_to_peaks,
spike_peak_dtype,
)


Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_numpy_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
generate_recording,
)

from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.base import minimum_spike_dtype
from spikeinterface.core.testing import check_sortings_equal


Expand Down
3 changes: 2 additions & 1 deletion src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

from probeinterface import ProbeGroup

from .base import minimum_spike_dtype
from .baserecording import BaseRecording, BaseRecordingSegment
from .basesorting import BaseSorting, SpikeVectorSortingSegment, minimum_spike_dtype
from .basesorting import BaseSorting, SpikeVectorSortingSegment
from .core_tools import define_function_from_class, check_json
from .job_tools import split_job_kwargs
from .core_tools import is_path_remote
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
)


from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.base import minimum_spike_dtype


job_kwargs = dict(n_jobs=2, progress_bar=True, chunk_duration="1s")
Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/preprocessing/silence_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import numpy as np

from spikeinterface.core.base import base_peak_dtype
from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.core.node_pipeline import PeakDetector
from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording
from spikeinterface.preprocessing.rectify import RectifyRecording
from spikeinterface.preprocessing.common_reference import CommonReferenceRecording
from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording
from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype
import numpy as np


class DetectThresholdCrossing(PeakDetector):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/lupin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from spikeinterface.core.job_tools import fix_job_kwargs

from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.base import minimum_spike_dtype

from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
get_shuffled_recording_slices,
_set_optimal_chunk_size,
)
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.base import minimum_spike_dtype
from spikeinterface.core import compute_sparsity


Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from spikeinterface.core.job_tools import fix_job_kwargs

from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore, whiten
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.base import minimum_spike_dtype

from spikeinterface.sortingcomponents.tools import cache_preprocessing, clean_cache_preprocessing

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
else:
HAVE_HDBSCAN = False

from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.base import minimum_spike_dtype
from spikeinterface.core.waveform_tools import estimate_templates
from spikeinterface.sortingcomponents.clustering.merging_tools import merge_peak_labels_from_templates
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/clustering/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from spikeinterface.core import Templates, estimate_templates, fix_job_kwargs
from spikeinterface.core.basesorting import minimum_spike_dtype
from spikeinterface.core.base import minimum_spike_dtype

# TODO find a way to attach a a sparse_mask to a given features (waveforms, pca, tsvd ....)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

import numpy as np


from spikeinterface.core.base import base_peak_dtype
from spikeinterface.core.baserecording import BaseRecording
from spikeinterface.core.node_pipeline import (
PeakDetector,
WaveformsNode,
ExtractSparseWaveforms,
base_peak_dtype,
)

expanded_base_peak_dtype = np.dtype(base_peak_dtype + [("iteration", "int8")])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,14 @@

from __future__ import annotations

import importlib.util
import numpy as np


from spikeinterface.core.node_pipeline import (
PeakDetector,
base_peak_dtype,
)

from spikeinterface.core.base import base_peak_dtype
from spikeinterface.core.node_pipeline import PeakDetector
from spikeinterface.core.recording_tools import get_channel_distances, get_random_data_chunks
from spikeinterface.postprocessing.localization_tools import get_convolution_weights

import importlib.util

numba_spec = importlib.util.find_spec("numba")
if numba_spec is not None:
Expand Down