diff --git a/examples/modules_gallery/core/plot_1_recording_extractor.py b/examples/modules_gallery/core/plot_1_recording_extractor.py index a7c2403c4a..f20bf6497d 100644 --- a/examples/modules_gallery/core/plot_1_recording_extractor.py +++ b/examples/modules_gallery/core/plot_1_recording_extractor.py @@ -79,7 +79,7 @@ # Note that this new recording is now "on disk" and not "in memory" as the Numpy recording. # This means that the loading is "lazy" and the data are not loaded in memory. -recording2 = se.BinaryRecordingExtractor(file_paths, sampling_frequency, num_channels, traces0.dtype) +recording2 = se.BinaryRecordingExtractor(file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=traces0.dtype) print(recording2) ############################################################################## diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index d4d0beb5fa..e3baf12e84 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -450,7 +450,7 @@ def _save(self, format="binary", **save_kwargs): binary_rec = BinaryRecordingExtractor( file_paths=file_paths, sampling_frequency=self.get_sampling_frequency(), - num_chan=self.get_num_channels(), + num_channels=self.get_num_channels(), dtype=dtype, t_starts=t_starts, channel_ids=self.get_channel_ids(), diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index c04a1c6ec7..f56cc28667 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -1,13 +1,12 @@ from typing import List, Union import mmap - -import shutil +import warnings from pathlib import Path import numpy as np from .baserecording import BaseRecording, BaseRecordingSegment -from .core_tools import read_binary_recording, write_binary_recording, define_function_from_class +from .core_tools import write_binary_recording, define_function_from_class from .job_tools import _shared_job_kwargs_doc @@ -21,7 +20,9 @@ class BinaryRecordingExtractor(BaseRecording): Path to the binary file sampling_frequency: float The sampling frequency - num_chan: int + num_channels: int + Number of channels + num_chan: int [deprecated, use num_channels instead, will be removed as early as v0.100.0] Number of channels dtype: str or dtype The dtype of the binary file @@ -40,6 +41,10 @@ class BinaryRecordingExtractor(BaseRecording): is_filtered: bool or None If True, the recording is assumed to be filtered. If None, is_filtered is not set. + Notes + ----- + When both num_channels and num_chan are provided, `num_channels` is used and `num_chan` is ignored. + Returns ------- recording: BinaryRecordingExtractor @@ -55,8 +60,8 @@ def __init__( self, file_paths, sampling_frequency, - num_chan, dtype, + num_channels=None, t_starts=None, channel_ids=None, time_axis=0, @@ -64,34 +69,41 @@ def __init__( gain_to_uV=None, offset_to_uV=None, is_filtered=None, + num_chan=None, ): + # This assigns num_channels if num_channels is not None, otherwise num_chan is assigned + num_channels = num_channels or num_chan + assert num_channels is not None, "You must provide num_channels or num_chan" + if num_chan is not None: + warnings.warn("`num_chan` is to be deprecated in version 0.100, please use `num_channels` instead") + if channel_ids is None: - channel_ids = list(range(num_chan)) + channel_ids = list(range(num_channels)) else: - assert len(channel_ids) == num_chan, "Provided recording channels have the wrong length" + assert len(channel_ids) == num_channels, "Provided recording channels have the wrong length" BaseRecording.__init__(self, sampling_frequency, channel_ids, dtype) if isinstance(file_paths, list): # several segment - datfiles = [Path(p) for p in file_paths] + file_path_list = [Path(p) for p in file_paths] else: # one segment - datfiles = [Path(file_paths)] + file_path_list = [Path(file_paths)] if t_starts is not None: - assert len(t_starts) == len(datfiles), "t_starts must be a list of same size than file_paths" + assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths" t_starts = [float(t_start) for t_start in t_starts] dtype = np.dtype(dtype) - for i, datfile in enumerate(datfiles): + for i, file_path in enumerate(file_path_list): if t_starts is None: t_start = None else: t_start = t_starts[i] rec_segment = BinaryRecordingSegment( - datfile, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset + file_path, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset ) self.add_recording_segment(rec_segment) @@ -105,10 +117,11 @@ def __init__( self.set_channel_offsets(offset_to_uV) self._kwargs = { - "file_paths": [str(e.absolute()) for e in datfiles], + "file_paths": [str(e.absolute()) for e in file_path_list], "sampling_frequency": sampling_frequency, "t_starts": t_starts, - "num_chan": num_chan, + "num_channels": num_channels, + "num_chan": num_channels, # TODO: This should be here at least till version 0.100.0 "dtype": dtype.str, "channel_ids": channel_ids, "time_axis": time_axis, @@ -142,7 +155,7 @@ def get_binary_description(self): d = dict( file_paths=self._kwargs["file_paths"], dtype=np.dtype(self._kwargs["dtype"]), - num_channels=self._kwargs["num_chan"], + num_channels=self._kwargs["num_channels"], time_axis=self._kwargs["time_axis"], file_offset=self._kwargs["file_offset"], ) @@ -155,23 +168,23 @@ def get_binary_description(self): class BinaryRecordingSegment(BaseRecordingSegment): - def __init__(self, datfile, sampling_frequency, t_start, num_chan, dtype, time_axis, file_offset): + def __init__(self, datfile, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset): BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start) - self.num_chan = num_chan + self.num_channels = num_channels self.dtype = np.dtype(dtype) self.file_offset = file_offset self.time_axis = time_axis self.datfile = datfile self.file = open(self.datfile, "r") - self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_chan * np.dtype(dtype).itemsize) + self.num_samples = (Path(datfile).stat().st_size - file_offset) // (num_channels * np.dtype(dtype).itemsize) if self.time_axis == 0: - self.shape = (self.num_samples, self.num_chan) + self.shape = (self.num_samples, self.num_channels) else: - self.shape = (self.num_chan, self.num_samples) + self.shape = (self.num_channels, self.num_samples) byte_offset = self.file_offset dtype_size_bytes = self.dtype.itemsize - data_size_bytes = dtype_size_bytes * self.num_samples * self.num_chan + data_size_bytes = dtype_size_bytes * self.num_samples * self.num_channels self.memmap_offset, self.array_offset = divmod(byte_offset, mmap.ALLOCATIONGRANULARITY) self.memmap_length = data_size_bytes + self.array_offset diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 230ff40b47..ed9a79d055 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -34,7 +34,9 @@ def test_BaseRecording(): for i in range(num_seg): a = np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan)) a[:] = np.random.randn(*a.shape).astype(dtype) - rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype) + rec = BinaryRecordingExtractor( + file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_chan, dtype=dtype + ) assert rec.get_num_segments() == 2 assert rec.get_num_channels() == 3 @@ -228,14 +230,25 @@ def test_BaseRecording(): assert np.dtype(rec_float32.get_traces().dtype) == np.float32 # test with t_start - rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype, t_starts=np.arange(num_seg) * 10.0) + rec = BinaryRecordingExtractor( + file_paths=file_paths, + sampling_frequency=sampling_frequency, + num_channels=num_chan, + dtype=dtype, + t_starts=np.arange(num_seg) * 10.0, + ) times1 = rec.get_times(1) folder = cache_folder / "recording_with_t_start" rec2 = rec.save(folder=folder) assert np.allclose(times1, rec2.get_times(1)) # test with time_vector - rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype) + rec = BinaryRecordingExtractor( + file_paths=file_paths, + sampling_frequency=sampling_frequency, + num_channels=num_chan, + dtype=dtype, + ) rec.set_times(np.arange(num_samples) / sampling_frequency + 30.0, segment_index=0) rec.set_times(np.arange(num_samples) / sampling_frequency + 40.0, segment_index=1) times1 = rec.get_times(1) diff --git a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py index 1d2c6e4c21..fb4c3ee3c4 100644 --- a/src/spikeinterface/core/tests/test_binaryrecordingextractor.py +++ b/src/spikeinterface/core/tests/test_binaryrecordingextractor.py @@ -13,17 +13,21 @@ def test_BinaryRecordingExtractor(): num_seg = 2 - num_chan = 3 + num_channels = 3 num_samples = 30 sampling_frequency = 10000 dtype = "int16" file_paths = [cache_folder / f"test_BinaryRecordingExtractor_{i}.raw" for i in range(num_seg)] for i in range(num_seg): - np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan)) + np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_channels)) - rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype) - print(rec) + rec = BinaryRecordingExtractor( + file_paths=file_paths, + sampling_frequency=sampling_frequency, + num_channels=num_channels, + dtype=dtype, + ) file_paths = [cache_folder / f"test_BinaryRecordingExtractor_copied_{i}.raw" for i in range(num_seg)] BinaryRecordingExtractor.write_recording(rec, file_paths) @@ -44,9 +48,12 @@ def test_round_trip(tmp_path): BinaryRecordingExtractor.write_recording(recording=recording, dtype=dtype, file_paths=file_path) sampling_frequency = recording.get_sampling_frequency() - num_chan = recording.get_num_channels() + num_channels = recording.get_num_channels() binary_recorder = BinaryRecordingExtractor( - file_paths=file_path, sampling_frequency=sampling_frequency, num_chan=num_chan, dtype=dtype + file_paths=file_path, + sampling_frequency=sampling_frequency, + num_channels=num_channels, + dtype=dtype, ) assert np.allclose(recording.get_traces(), binary_recorder.get_traces()) diff --git a/src/spikeinterface/core/tests/test_channelslicerecording.py b/src/spikeinterface/core/tests/test_channelslicerecording.py index c5ba89bd84..08bb22a2c8 100644 --- a/src/spikeinterface/core/tests/test_channelslicerecording.py +++ b/src/spikeinterface/core/tests/test_channelslicerecording.py @@ -25,7 +25,12 @@ def test_ChannelSliceRecording(): for i in range(num_seg): traces = np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan)) traces[:] = np.arange(3)[None, :] - rec = BinaryRecordingExtractor(file_paths, sampling_frequency, num_chan, dtype) + rec = BinaryRecordingExtractor( + file_paths=file_paths, + sampling_frequency=sampling_frequency, + num_channels=num_chan, + dtype=dtype, + ) # keep original ids rec_sliced = ChannelSliceRecording(rec, channel_ids=[0, 2]) diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index 434f7f5dfa..cb48e3d20f 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -65,7 +65,7 @@ def __init__(self, file_path): self, file_paths=bin_file, sampling_frequency=float(params["fs"]), - num_chan=nb_channels, + num_channels=nb_channels, dtype=params["dtype"], time_axis=time_axis, )