Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
# Note that this new recording is now "on disk" rather than "in memory" like the NumPy recording.
# This means that the loading is "lazy" and the data are not loaded into memory.

recording2 = se.BinaryRecordingExtractor(file_paths, sampling_frequency, num_channels, traces0.dtype)
recording2 = se.BinaryRecordingExtractor(file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=traces0.dtype)
print(recording2)

##############################################################################
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def _save(self, format="binary", **save_kwargs):
binary_rec = BinaryRecordingExtractor(
file_paths=file_paths,
sampling_frequency=self.get_sampling_frequency(),
num_chan=self.get_num_channels(),
num_channels=self.get_num_channels(),
dtype=dtype,
t_starts=t_starts,
channel_ids=self.get_channel_ids(),
Expand Down
55 changes: 34 additions & 21 deletions src/spikeinterface/core/binaryrecordingextractor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import List, Union
import mmap

import shutil
import warnings
from pathlib import Path

import numpy as np

from .baserecording import BaseRecording, BaseRecordingSegment
from .core_tools import read_binary_recording, write_binary_recording, define_function_from_class
from .core_tools import write_binary_recording, define_function_from_class
from .job_tools import _shared_job_kwargs_doc


Expand All @@ -21,7 +20,9 @@ class BinaryRecordingExtractor(BaseRecording):
Path to the binary file
sampling_frequency: float
The sampling frequency
num_chan: int
num_channels: int
Number of channels
num_chan: int [deprecated, use num_channels instead, will be removed as early as v0.100.0]
Number of channels
dtype: str or dtype
The dtype of the binary file
Expand All @@ -40,6 +41,10 @@ class BinaryRecordingExtractor(BaseRecording):
is_filtered: bool or None
If True, the recording is assumed to be filtered. If None, is_filtered is not set.

Notes
-----
When both `num_channels` and `num_chan` are provided, `num_channels` is used and `num_chan` is ignored.

Returns
-------
recording: BinaryRecordingExtractor
Expand All @@ -55,43 +60,50 @@ def __init__(
self,
file_paths,
sampling_frequency,
num_chan,
dtype,
num_channels=None,
t_starts=None,
channel_ids=None,
time_axis=0,
file_offset=0,
gain_to_uV=None,
offset_to_uV=None,
is_filtered=None,
num_chan=None,
):
# `or` keeps num_channels when it is truthy (i.e. not None and not 0); otherwise it falls back to num_chan
num_channels = num_channels or num_chan
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if they are both given, and maybe with different values? E.g. num_channels=16 and num_chan=32?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This:
https://en.wikipedia.org/wiki/Short-circuit_evaluation

Together with this:
https://docs.python.org/3/library/stdtypes.html#truth-value-testing

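In other words, `num_channels or num_chan` evaluates to `num_channels` whenever it is truthy, so in your example 16 is used and `num_chan=32` is ignored. It is a common Python idiom, but maybe I should make a point of avoiding it, as it can be confusing.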

assert num_channels is not None, "You must provide num_channels or num_chan"
if num_chan is not None:
warnings.warn("`num_chan` is to be deprecated in version 0.100, please use `num_channels` instead")

if channel_ids is None:
channel_ids = list(range(num_chan))
channel_ids = list(range(num_channels))
else:
assert len(channel_ids) == num_chan, "Provided recording channels have the wrong length"
assert len(channel_ids) == num_channels, "Provided recording channels have the wrong length"

BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype)

if isinstance(file_paths, list):
# several segment
datfiles = [Path(p) for p in file_paths]
file_path_list = [Path(p) for p in file_paths]
else:
# one segment
datfiles = [Path(file_paths)]
file_path_list = [Path(file_paths)]

if t_starts is not None:
assert len(t_starts) == len(datfiles), "t_starts must be a list of same size than file_paths"
assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths"
t_starts = [float(t_start) for t_start in t_starts]

dtype = np.dtype(dtype)

for i, datfile in enumerate(datfiles):
for i, file_path in enumerate(file_path_list):
if t_starts is None:
t_start = None
else:
t_start = t_starts[i]
rec_segment = BinaryRecordingSegment(
datfile, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset
file_path, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset
)
self.add_recording_segment(rec_segment)

Expand All @@ -105,10 +117,11 @@ def __init__(
self.set_channel_offsets(offset_to_uV)

self._kwargs = {
"file_paths": [str(e.absolute()) for e in datfiles],
"file_paths": [str(e.absolute()) for e in file_path_list],
"sampling_frequency": sampling_frequency,
"t_starts": t_starts,
"num_chan": num_chan,
"num_channels": num_channels,
"num_chan": num_channels, # TODO: This should be here at least till version 0.100.0
"dtype": dtype.str,
"channel_ids": channel_ids,
"time_axis": time_axis,
Expand Down Expand Up @@ -142,7 +155,7 @@ def get_binary_description(self):
d = dict(
file_paths=self._kwargs["file_paths"],
dtype=np.dtype(self._kwargs["dtype"]),
num_channels=self._kwargs["num_chan"],
num_channels=self._kwargs["num_channels"],
time_axis=self._kwargs["time_axis"],
file_offset=self._kwargs["file_offset"],
)
Expand All @@ -155,23 +168,23 @@ def get_binary_description(self):


class BinaryRecordingSegment(BaseRecordingSegment):
def __init__(self, datfile, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset):
def __init__(self, datfile, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset):
BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start)
self.num_chan = num_chan
self.num_channels = num_channels
self.dtype = np.dtype(dtype)
self.file_offset = file_offset
self.time_axis = time_axis
self.datfile = datfile
self.file = open(self.datfile, "r")
self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_chan * np.dtype(dtype).itemsize)
self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_channels * np.dtype(dtype).itemsize)
if self.time_axis == 0:
self.shape = (self.num_samples, self.num_chan)
self.shape = (self.num_samples, self.num_channels)
else:
self.shape = (self.num_chan, self.num_samples)
self.shape = (self.num_channels, self.num_samples)

byte_offset = self.file_offset
dtype_size_bytes = self.dtype.itemsize
data_size_bytes = dtype_size_bytes * self.num_samples * self.num_chan
data_size_bytes = dtype_size_bytes * self.num_samples * self.num_channels
self.memmap_offset, self.array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY)
self.memmap_length = data_size_bytes + self.array_offset

Expand Down
19 changes: 16 additions & 3 deletions src/spikeinterface/core/tests/test_baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def test_BaseRecording():
for i in range(num_seg):
a = np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan))
a[:] = np.random.randn(*a.shape).astype(dtype)
rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype)
rec = BinaryRecordingExtractor(
file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_chan, dtype=dtype
)

assert rec.get_num_segments() == 2
assert rec.get_num_channels() == 3
Expand Down Expand Up @@ -228,14 +230,25 @@ def test_BaseRecording():
assert np.dtype(rec_float32.get_traces().dtype) == np.float32

# test with t_start
rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype, t_starts=np.arange(num_seg) * 10.0)
rec = BinaryRecordingExtractor(
file_paths=file_paths,
sampling_frequency=sampling_frequency,
num_channels=num_chan,
dtype=dtype,
t_starts=np.arange(num_seg) * 10.0,
)
times1 = rec.get_times(1)
folder = cache_folder / "recording_with_t_start"
rec2 = rec.save(folder=folder)
assert np.allclose(times1, rec2.get_times(1))

# test with time_vector
rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype)
rec = BinaryRecordingExtractor(
file_paths=file_paths,
sampling_frequency=sampling_frequency,
num_channels=num_chan,
dtype=dtype,
)
rec.set_times(np.arange(num_samples) / sampling_frequency + 30.0, segment_index=0)
rec.set_times(np.arange(num_samples) / sampling_frequency + 40.0, segment_index=1)
times1 = rec.get_times(1)
Expand Down
19 changes: 13 additions & 6 deletions src/spikeinterface/core/tests/test_binaryrecordingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@

def test_BinaryRecordingExtractor():
num_seg = 2
num_chan = 3
num_channels = 3
num_samples = 30
sampling_frequency = 10000
dtype = "int16"

file_paths = [cache_folder / f"test_BinaryRecordingExtractor_{i}.raw" for i in range(num_seg)]
for i in range(num_seg):
np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan))
np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_channels))

rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype)
print(rec)
rec = BinaryRecordingExtractor(
file_paths=file_paths,
sampling_frequency=sampling_frequency,
num_channels=num_channels,
dtype=dtype,
)

file_paths = [cache_folder / f"test_BinaryRecordingExtractor_copied_{i}.raw" for i in range(num_seg)]
BinaryRecordingExtractor.write_recording(rec, file_paths)
Expand All @@ -44,9 +48,12 @@ def test_round_trip(tmp_path):
BinaryRecordingExtractor.write_recording(recording=recording, dtype=dtype, file_paths=file_path)

sampling_frequency = recording.get_sampling_frequency()
num_chan = recording.get_num_channels()
num_channels = recording.get_num_channels()
binary_recorder = BinaryRecordingExtractor(
file_paths=file_path, sampling_frequency=sampling_frequency, num_chan=num_chan, dtype=dtype
file_paths=file_path,
sampling_frequency=sampling_frequency,
num_channels=num_channels,
dtype=dtype,
)

assert np.allclose(recording.get_traces(), binary_recorder.get_traces())
Expand Down
7 changes: 6 additions & 1 deletion src/spikeinterface/core/tests/test_channelslicerecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ def test_ChannelSliceRecording():
for i in range(num_seg):
traces = np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan))
traces[:] = np.arange(3)[None, :]
rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype)
rec = BinaryRecordingExtractor(
file_paths=file_paths,
sampling_frequency=sampling_frequency,
num_channels=num_chan,
dtype=dtype,
)

# keep original ids
rec_sliced = ChannelSliceRecording(rec, channel_ids=[0, 2])
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/shybridextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, file_path):
self,
file_paths=bin_file,
sampling_frequency=float(params["fs"]),
num_chan=nb_channels,
num_channels=nb_channels,
dtype=params["dtype"],
time_axis=time_axis,
)
Expand Down