Skip to content

Commit 561c3c7

Browse files
authored
Move structured dtypes to base (#4314)
1 parent 9436839 commit 561c3c7

File tree

19 files changed

+50
-43
lines changed

19 files changed

+50
-43
lines changed

src/spikeinterface/benchmark/benchmark_matching.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import numpy as np
1414
from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
15-
from spikeinterface.core.basesorting import minimum_spike_dtype
15+
from spikeinterface.core.base import minimum_spike_dtype
1616

1717

1818
class MatchingBenchmark(Benchmark):

src/spikeinterface/benchmark/benchmark_peak_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import numpy as np
1414
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
1515
from .benchmark_base import Benchmark, BenchmarkStudy
16-
from spikeinterface.core.basesorting import minimum_spike_dtype
16+
from spikeinterface.core.base import minimum_spike_dtype
1717
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
1818
from .benchmark_plot_tools import fit_sigmoid, sigmoid
1919

src/spikeinterface/core/base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,31 @@
2727
from .job_tools import _shared_job_kwargs_doc
2828

2929

30+
# base dtypes used throughout spikeinterface
31+
base_peak_dtype = [
32+
("sample_index", "int64"),
33+
("channel_index", "int64"),
34+
("amplitude", "float64"),
35+
("segment_index", "int64"),
36+
]
37+
38+
spike_peak_dtype = base_peak_dtype + [
39+
("unit_index", "int64"),
40+
]
41+
42+
minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
43+
44+
base_period_dtype = [
45+
("start_sample_index", "int64"),
46+
("end_sample_index", "int64"),
47+
("segment_index", "int64"),
48+
]
49+
50+
unit_period_dtype = base_period_dtype + [
51+
("unit_index", "int64"),
52+
]
53+
54+
3055
class BaseExtractor:
3156
"""
3257
Base class for Recording/Sorting

src/spikeinterface/core/basesorting.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,10 @@
55

66
import numpy as np
77

8-
from .base import BaseExtractor, BaseSegment
8+
from .base import BaseExtractor, BaseSegment, minimum_spike_dtype
99
from .waveform_tools import has_exceeding_spikes
1010

1111

12-
minimum_spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
13-
14-
1512
class BaseSorting(BaseExtractor):
1613
"""
1714
Abstract class representing several segment several units and relative spiketrains.

src/spikeinterface/core/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from typing import Literal, Optional
66
from math import ceil
77

8+
from .base import minimum_spike_dtype
89
from .basesorting import SpikeVectorSortingSegment
910
from .numpyextractors import NumpySorting
10-
from .basesorting import minimum_spike_dtype
1111

1212
from probeinterface import Probe, generate_linear_probe, generate_multi_columns_probe
1313

src/spikeinterface/core/node_pipeline.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,12 @@
1010

1111
import numpy as np
1212

13+
from spikeinterface.core.base import base_peak_dtype, spike_peak_dtype
1314
from spikeinterface.core import BaseRecording, get_chunk_with_margin
1415
from spikeinterface.core.job_tools import ChunkRecordingExecutor, fix_job_kwargs, _shared_job_kwargs_doc
1516
from spikeinterface.core import get_channel_distances
1617

1718

18-
base_peak_dtype = [
19-
("sample_index", "int64"),
20-
("channel_index", "int64"),
21-
("amplitude", "float64"),
22-
("segment_index", "int64"),
23-
]
24-
25-
26-
spike_peak_dtype = base_peak_dtype + [
27-
("unit_index", "int64"),
28-
]
29-
30-
3119
class PipelineNode:
3220

3321
# If False (general case) then compute(traces_chunk, *node_input_args)

src/spikeinterface/core/numpyextractors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
BaseSnippets,
1313
BaseSnippetsSegment,
1414
)
15-
from .basesorting import minimum_spike_dtype
15+
from .base import minimum_spike_dtype
1616
from .core_tools import make_shared_array
1717
from .recording_tools import write_memory_recording
1818
from multiprocessing.shared_memory import SharedMemory

src/spikeinterface/core/tests/test_node_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import shutil
55

66
from spikeinterface import create_sorting_analyzer, get_template_extremum_channel, generate_ground_truth_recording
7+
from spikeinterface.core.base import spike_peak_dtype
78
from spikeinterface.core.job_tools import divide_recording_into_chunks
89

10+
911
# from spikeinterface.sortingcomponents.peak_detection import detect_peaks
1012
from spikeinterface.core.node_pipeline import (
1113
run_node_pipeline,
@@ -14,7 +16,6 @@
1416
PipelineNode,
1517
ExtractDenseWaveforms,
1618
sorting_to_peaks,
17-
spike_peak_dtype,
1819
)
1920

2021

src/spikeinterface/core/tests/test_numpy_extractors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
generate_recording,
1414
)
1515

16-
from spikeinterface.core.basesorting import minimum_spike_dtype
16+
from spikeinterface.core.base import minimum_spike_dtype
1717
from spikeinterface.core.testing import check_sortings_equal
1818

1919

src/spikeinterface/core/zarrextractors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77

88
from probeinterface import ProbeGroup
99

10+
from .base import minimum_spike_dtype
1011
from .baserecording import BaseRecording, BaseRecordingSegment
11-
from .basesorting import BaseSorting, SpikeVectorSortingSegment, minimum_spike_dtype
12+
from .basesorting import BaseSorting, SpikeVectorSortingSegment
1213
from .core_tools import define_function_from_class, check_json
1314
from .job_tools import split_job_kwargs
1415
from .core_tools import is_path_remote

0 commit comments

Comments
 (0)