Skip to content

Commit 33a4d31

Browse files
authored
Merge pull request #1742 from h-mayorquin/make_binary_recording_memmap_efficient_II
Make binary recording memmap efficient II (using native memmaps)
2 parents 4ad9e0f + 879bfbc commit 33a4d31

File tree

2 files changed

+65
-4
lines changed

2 files changed

+65
-4
lines changed

src/spikeinterface/core/binaryrecordingextractor.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Union
2+
import mmap
23

34
import shutil
45
from pathlib import Path
@@ -156,25 +157,58 @@ def get_binary_description(self):
156157
class BinaryRecordingSegment(BaseRecordingSegment):
157158
def __init__(self, datfile, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset):
158159
BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start)
159-
self._timeseries = read_binary_recording(datfile, num_chan, dtype, time_axis, file_offset)
160+
self.num_chan = num_chan
161+
self.dtype = np.dtype(dtype)
162+
self.file_offset = file_offset
163+
self.time_axis = time_axis
164+
self.datfile = datfile
165+
self.file = open(self.datfile, "r")
166+
self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_chan * np.dtype(dtype).itemsize)
167+
if self.time_axis == 0:
168+
self.shape = (self.num_samples, self.num_chan)
169+
else:
170+
self.shape = (self.num_chan, self.num_samples)
171+
172+
byte_offset = self.file_offset
173+
dtype_size_bytes = self.dtype.itemsize
174+
data_size_bytes = dtype_size_bytes * self.num_samples * self.num_chan
175+
self.memmap_offset, self.array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY)
176+
self.memmap_length = data_size_bytes + self.array_offset
160177

161178
def get_num_samples(self) -> int:
162179
"""Returns the number of samples in this signal block
163180
164181
Returns:
165182
SampleIndex: Number of samples in the signal block
166183
"""
167-
return self._timeseries.shape[0]
184+
return self.num_samples
168185

169186
def get_traces(
170187
self,
171188
start_frame: Union[int, None] = None,
172189
end_frame: Union[int, None] = None,
173190
channel_indices: Union[List, None] = None,
174191
) -> np.ndarray:
175-
traces = self._timeseries[start_frame:end_frame]
192+
length = self.memmap_length
193+
memmap_offset = self.memmap_offset
194+
memmap_obj = mmap.mmap(self.file.fileno(), length=length, access=mmap.ACCESS_READ, offset=memmap_offset)
195+
196+
array = np.ndarray.__new__(
197+
np.ndarray,
198+
shape=self.shape,
199+
dtype=self.dtype,
200+
buffer=memmap_obj,
201+
order="C",
202+
offset=self.array_offset,
203+
)
204+
205+
if self.time_axis == 1:
206+
array = array.T
207+
208+
traces = array[start_frame:end_frame]
176209
if channel_indices is not None:
177210
traces = traces[:, channel_indices]
211+
178212
return traces
179213

180214

src/spikeinterface/core/tests/test_binaryrecordingextractor.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pathlib import Path
44

55
from spikeinterface.core import BinaryRecordingExtractor
6-
6+
from spikeinterface.core.numpyextractors import NumpyRecording
77

88
if hasattr(pytest, "global_test_folder"):
99
cache_folder = pytest.global_test_folder / "core"
@@ -32,5 +32,32 @@ def test_BinaryRecordingExtractor():
3232
assert (cache_folder / "test_BinaryRecordingExtractor_copied_0.raw").is_file()
3333

3434

35+
def test_round_trip(tmp_path):
36+
num_channels = 10
37+
num_samples = 50
38+
traces_list = [np.ones(shape=(num_samples, num_channels), dtype="int32")]
39+
sampling_frequency = 30_000.0
40+
recording = NumpyRecording(traces_list=traces_list, sampling_frequency=sampling_frequency)
41+
42+
file_path = tmp_path / "test_BinaryRecordingExtractor.raw"
43+
dtype = recording.get_dtype()
44+
BinaryRecordingExtractor.write_recording(recording=recording, dtype=dtype, file_paths=file_path)
45+
46+
sampling_frequency = recording.get_sampling_frequency()
47+
num_chan = recording.get_num_channels()
48+
binary_recorder = BinaryRecordingExtractor(
49+
file_paths=file_path, sampling_frequency=sampling_frequency, num_chan=num_chan, dtype=dtype
50+
)
51+
52+
assert np.allclose(recording.get_traces(), binary_recorder.get_traces())
53+
54+
start_frame = 200
55+
end_frame = 500
56+
smaller_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame)
57+
binary_smaller_traces = binary_recorder.get_traces(start_frame=start_frame, end_frame=end_frame)
58+
59+
np.allclose(smaller_traces, binary_smaller_traces)
60+
61+
3562
if __name__ == "__main__":
3663
test_BinaryRecordingExtractor()

0 commit comments

Comments
 (0)