Skip to content

Commit dc498f2

Browse files
authored
Merge pull request #1754 from catalystneuro/binary_recording
`num_chan` to `num_channels` in `BinaryRecordingExtractor`
2 parents a23c725 + 4b4e7d4 commit dc498f2

File tree

7 files changed

+72
-34
lines changed

7 files changed

+72
-34
lines changed

examples/modules_gallery/core/plot_1_recording_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
# Note that this new recording is now "on disk" and not "in memory" as the Numpy recording.
8080
# This means that the loading is "lazy" and the data are not loaded in memory.
8181

82-
recording2 = se.BinaryRecordingExtractor(file_paths, sampling_frequency, num_channels, traces0.dtype)
82+
recording2 = se.BinaryRecordingExtractor(file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=traces0.dtype)
8383
print(recording2)
8484

8585
##############################################################################

src/spikeinterface/core/baserecording.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,7 @@ def _save(self, format="binary", **save_kwargs):
450450
binary_rec = BinaryRecordingExtractor(
451451
file_paths=file_paths,
452452
sampling_frequency=self.get_sampling_frequency(),
453-
num_chan=self.get_num_channels(),
453+
num_channels=self.get_num_channels(),
454454
dtype=dtype,
455455
t_starts=t_starts,
456456
channel_ids=self.get_channel_ids(),

src/spikeinterface/core/binaryrecordingextractor.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from typing import List, Union
22
import mmap
3-
4-
import shutil
3+
import warnings
54
from pathlib import Path
65

76
import numpy as np
87

98
from .baserecording import BaseRecording, BaseRecordingSegment
10-
from .core_tools import read_binary_recording, write_binary_recording, define_function_from_class
9+
from .core_tools import write_binary_recording, define_function_from_class
1110
from .job_tools import _shared_job_kwargs_doc
1211

1312

@@ -21,7 +20,9 @@ class BinaryRecordingExtractor(BaseRecording):
2120
Path to the binary file
2221
sampling_frequency: float
2322
The sampling frequency
24-
num_chan: int
23+
num_channels: int
24+
Number of channels
25+
num_chan: int [deprecated, use num_channels instead, will be removed as early as v0.100.0]
2526
Number of channels
2627
dtype: str or dtype
2728
The dtype of the binary file
@@ -40,6 +41,10 @@ class BinaryRecordingExtractor(BaseRecording):
4041
is_filtered: bool or None
4142
If True, the recording is assumed to be filtered. If None, is_filtered is not set.
4243
44+
Notes
45+
-----
46+
When both num_channels and num_chan are provided, `num_channels` is used and `num_chan` is ignored.
47+
4348
Returns
4449
-------
4550
recording: BinaryRecordingExtractor
@@ -55,43 +60,50 @@ def __init__(
5560
self,
5661
file_paths,
5762
sampling_frequency,
58-
num_chan,
5963
dtype,
64+
num_channels=None,
6065
t_starts=None,
6166
channel_ids=None,
6267
time_axis=0,
6368
file_offset=0,
6469
gain_to_uV=None,
6570
offset_to_uV=None,
6671
is_filtered=None,
72+
num_chan=None,
6773
):
74+
# This assigns num_channels if num_channels is not None, otherwise num_chan is assigned
75+
num_channels = num_channels or num_chan
76+
assert num_channels is not None, "You must provide num_channels or num_chan"
77+
if num_chan is not None:
78+
warnings.warn("`num_chan` is to be deprecated in version 0.100, please use `num_channels` instead")
79+
6880
if channel_ids is None:
69-
channel_ids = list(range(num_chan))
81+
channel_ids = list(range(num_channels))
7082
else:
71-
assert len(channel_ids) == num_chan, "Provided recording channels have the wrong length"
83+
assert len(channel_ids) == num_channels, "Provided recording channels have the wrong length"
7284

7385
BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype)
7486

7587
if isinstance(file_paths, list):
7688
# several segment
77-
datfiles = [Path(p) for p in file_paths]
89+
file_path_list = [Path(p) for p in file_paths]
7890
else:
7991
# one segment
80-
datfiles = [Path(file_paths)]
92+
file_path_list = [Path(file_paths)]
8193

8294
if t_starts is not None:
83-
assert len(t_starts) == len(datfiles), "t_starts must be a list of same size than file_paths"
95+
assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths"
8496
t_starts = [float(t_start) for t_start in t_starts]
8597

8698
dtype = np.dtype(dtype)
8799

88-
for i, datfile in enumerate(datfiles):
100+
for i, file_path in enumerate(file_path_list):
89101
if t_starts is None:
90102
t_start = None
91103
else:
92104
t_start = t_starts[i]
93105
rec_segment = BinaryRecordingSegment(
94-
datfile, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset
106+
file_path, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset
95107
)
96108
self.add_recording_segment(rec_segment)
97109

@@ -105,10 +117,11 @@ def __init__(
105117
self.set_channel_offsets(offset_to_uV)
106118

107119
self._kwargs = {
108-
"file_paths": [str(e.absolute()) for e in datfiles],
120+
"file_paths": [str(e.absolute()) for e in file_path_list],
109121
"sampling_frequency": sampling_frequency,
110122
"t_starts": t_starts,
111-
"num_chan": num_chan,
123+
"num_channels": num_channels,
124+
"num_chan": num_channels, # TODO: This should be here at least till version 0.100.0
112125
"dtype": dtype.str,
113126
"channel_ids": channel_ids,
114127
"time_axis": time_axis,
@@ -142,7 +155,7 @@ def get_binary_description(self):
142155
d = dict(
143156
file_paths=self._kwargs["file_paths"],
144157
dtype=np.dtype(self._kwargs["dtype"]),
145-
num_channels=self._kwargs["num_chan"],
158+
num_channels=self._kwargs["num_channels"],
146159
time_axis=self._kwargs["time_axis"],
147160
file_offset=self._kwargs["file_offset"],
148161
)
@@ -155,23 +168,23 @@ def get_binary_description(self):
155168

156169

157170
class BinaryRecordingSegment(BaseRecordingSegment):
158-
def __init__(self, datfile, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset):
171+
def __init__(self, datfile, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset):
159172
BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start)
160-
self.num_chan = num_chan
173+
self.num_channels = num_channels
161174
self.dtype = np.dtype(dtype)
162175
self.file_offset = file_offset
163176
self.time_axis = time_axis
164177
self.datfile = datfile
165178
self.file = open(self.datfile, "r")
166-
self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_chan * np.dtype(dtype).itemsize)
179+
self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_channels * np.dtype(dtype).itemsize)
167180
if self.time_axis == 0:
168-
self.shape = (self.num_samples, self.num_chan)
181+
self.shape = (self.num_samples, self.num_channels)
169182
else:
170-
self.shape = (self.num_chan, self.num_samples)
183+
self.shape = (self.num_channels, self.num_samples)
171184

172185
byte_offset = self.file_offset
173186
dtype_size_bytes = self.dtype.itemsize
174-
data_size_bytes = dtype_size_bytes * self.num_samples * self.num_chan
187+
data_size_bytes = dtype_size_bytes * self.num_samples * self.num_channels
175188
self.memmap_offset, self.array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY)
176189
self.memmap_length = data_size_bytes + self.array_offset
177190

src/spikeinterface/core/tests/test_baserecording.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def test_BaseRecording():
3434
for i in range(num_seg):
3535
a = np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan))
3636
a[:] = np.random.randn(*a.shape).astype(dtype)
37-
rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype)
37+
rec = BinaryRecordingExtractor(
38+
file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_chan, dtype=dtype
39+
)
3840

3941
assert rec.get_num_segments() == 2
4042
assert rec.get_num_channels() == 3
@@ -228,14 +230,25 @@ def test_BaseRecording():
228230
assert np.dtype(rec_float32.get_traces().dtype) == np.float32
229231

230232
# test with t_start
231-
rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype, t_starts=np.arange(num_seg) * 10.0)
233+
rec = BinaryRecordingExtractor(
234+
file_paths=file_paths,
235+
sampling_frequency=sampling_frequency,
236+
num_channels=num_chan,
237+
dtype=dtype,
238+
t_starts=np.arange(num_seg) * 10.0,
239+
)
232240
times1 = rec.get_times(1)
233241
folder = cache_folder / "recording_with_t_start"
234242
rec2 = rec.save(folder=folder)
235243
assert np.allclose(times1, rec2.get_times(1))
236244

237245
# test with time_vector
238-
rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype)
246+
rec = BinaryRecordingExtractor(
247+
file_paths=file_paths,
248+
sampling_frequency=sampling_frequency,
249+
num_channels=num_chan,
250+
dtype=dtype,
251+
)
239252
rec.set_times(np.arange(num_samples) / sampling_frequency + 30.0, segment_index=0)
240253
rec.set_times(np.arange(num_samples) / sampling_frequency + 40.0, segment_index=1)
241254
times1 = rec.get_times(1)

src/spikeinterface/core/tests/test_binaryrecordingextractor.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,21 @@
1313

1414
def test_BinaryRecordingExtractor():
1515
num_seg = 2
16-
num_chan = 3
16+
num_channels = 3
1717
num_samples = 30
1818
sampling_frequency = 10000
1919
dtype = "int16"
2020

2121
file_paths = [cache_folder / f"test_BinaryRecordingExtractor_{i}.raw" for i in range(num_seg)]
2222
for i in range(num_seg):
23-
np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan))
23+
np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_channels))
2424

25-
rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype)
26-
print(rec)
25+
rec = BinaryRecordingExtractor(
26+
file_paths=file_paths,
27+
sampling_frequency=sampling_frequency,
28+
num_channels=num_channels,
29+
dtype=dtype,
30+
)
2731

2832
file_paths = [cache_folder / f"test_BinaryRecordingExtractor_copied_{i}.raw" for i in range(num_seg)]
2933
BinaryRecordingExtractor.write_recording(rec, file_paths)
@@ -44,9 +48,12 @@ def test_round_trip(tmp_path):
4448
BinaryRecordingExtractor.write_recording(recording=recording, dtype=dtype, file_paths=file_path)
4549

4650
sampling_frequency = recording.get_sampling_frequency()
47-
num_chan = recording.get_num_channels()
51+
num_channels = recording.get_num_channels()
4852
binary_recorder = BinaryRecordingExtractor(
49-
file_paths=file_path, sampling_frequency=sampling_frequency, num_chan=num_chan, dtype=dtype
53+
file_paths=file_path,
54+
sampling_frequency=sampling_frequency,
55+
num_channels=num_channels,
56+
dtype=dtype,
5057
)
5158

5259
assert np.allclose(recording.get_traces(), binary_recorder.get_traces())

src/spikeinterface/core/tests/test_channelslicerecording.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,12 @@ def test_ChannelSliceRecording():
2525
for i in range(num_seg):
2626
traces = np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan))
2727
traces[:] = np.arange(3)[None, :]
28-
rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype)
28+
rec = BinaryRecordingExtractor(
29+
file_paths=file_paths,
30+
sampling_frequency=sampling_frequency,
31+
num_channels=num_chan,
32+
dtype=dtype,
33+
)
2934

3035
# keep original ids
3136
rec_sliced = ChannelSliceRecording(rec, channel_ids=[0, 2])

src/spikeinterface/extractors/shybridextractors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, file_path):
6565
self,
6666
file_paths=bin_file,
6767
sampling_frequency=float(params["fs"]),
68-
num_chan=nb_channels,
68+
num_channels=nb_channels,
6969
dtype=params["dtype"],
7070
time_axis=time_axis,
7171
)

0 commit comments

Comments
 (0)