Skip to content
40 changes: 37 additions & 3 deletions src/spikeinterface/core/binaryrecordingextractor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Union
import mmap

import shutil
from pathlib import Path
Expand Down Expand Up @@ -156,25 +157,58 @@ def get_binary_description(self):
class BinaryRecordingSegment(BaseRecordingSegment):
def __init__(self, datfile, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset):
BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start)
self._timeseries = read_binary_recording(datfile, num_chan, dtype, time_axis, file_offset)
self.num_chan = num_chan
self.dtype = np.dtype(dtype)
self.file_offset = file_offset
self.time_axis = time_axis
self.datfile = datfile
self.file = open(self.datfile, "r")
self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_chan * np.dtype(dtype).itemsize)
if self.time_axis == 0:
self.shape = (self.num_samples, self.num_chan)
else:
self.shape = (self.num_chan, self.num_samples)

byte_offset = self.file_offset
dtype_size_bytes = self.dtype.itemsize
data_size_bytes = dtype_size_bytes * self.num_samples * self.num_chan
self.memmap_offset, self.array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY)
self.memmap_length = data_size_bytes + self.array_offset

def get_num_samples(self) -> int:
"""Returns the number of samples in this signal block

Returns:
SampleIndex: Number of samples in the signal block
"""
return self._timeseries.shape[0]
return self.num_samples

def get_traces(
self,
start_frame: Union[int, None] = None,
end_frame: Union[int, None] = None,
channel_indices: Union[List, None] = None,
) -> np.ndarray:
traces = self._timeseries[start_frame:end_frame]
length = self.memmap_length
memmap_offset = self.memmap_offset
memmap_obj = mmap.mmap(self.file.fileno(), length=length, access=mmap.ACCESS_READ, offset=memmap_offset)

array = np.ndarray.__new__(
np.ndarray,
shape=self.shape,
dtype=self.dtype,
buffer=memmap_obj,
order="C",
offset=self.array_offset,
)

if self.time_axis == 1:
array = array.T

traces = array[start_frame:end_frame]
if channel_indices is not None:
traces = traces[:, channel_indices]

return traces


Expand Down
29 changes: 28 additions & 1 deletion src/spikeinterface/core/tests/test_binaryrecordingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path

from spikeinterface.core import BinaryRecordingExtractor

from spikeinterface.core.numpyextractors import NumpyRecording

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "core"
Expand Down Expand Up @@ -32,5 +32,32 @@ def test_BinaryRecordingExtractor():
assert (cache_folder / "test_BinaryRecordingExtractor_copied_0.raw").is_file()


def test_round_trip(tmp_path):
num_channels = 10
num_samples = 50
traces_list = [np.ones(shape=(num_samples, num_channels), dtype="int32")]
sampling_frequency = 30_000.0
recording = NumpyRecording(traces_list=traces_list, sampling_frequency=sampling_frequency)

file_path = tmp_path / "test_BinaryRecordingExtractor.raw"
dtype = recording.get_dtype()
BinaryRecordingExtractor.write_recording(recording=recording, dtype=dtype, file_paths=file_path)

sampling_frequency = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()
binary_recorder = BinaryRecordingExtractor(
file_paths=file_path, sampling_frequency=sampling_frequency, num_chan=num_chan, dtype=dtype
)

assert np.allclose(recording.get_traces(), binary_recorder.get_traces())

start_frame = 200
end_frame = 500
smaller_traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame)
binary_smaller_traces = binary_recorder.get_traces(start_frame=start_frame, end_frame=end_frame)

np.allclose(smaller_traces, binary_smaller_traces)


if __name__ == "__main__":
test_BinaryRecordingExtractor()