diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 4b4e1eb3cd..dfba7d1f22 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -61,7 +61,7 @@ def _init_binary_worker(recording, file_path_dict, dtype, byte_offest): worker_ctx["byte_offset"] = byte_offest worker_ctx["dtype"] = np.dtype(dtype) - file_dict = {segment_index: open(file_path, "r+") for segment_index, file_path in file_path_dict.items()} + file_dict = {segment_index: open(file_path, "rb+") for segment_index, file_path in file_path_dict.items()} worker_ctx["file_dict"] = file_dict return worker_ctx @@ -151,34 +151,16 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): num_channels = recording.get_num_channels() dtype_size_bytes = np.dtype(dtype).itemsize - # Calculate byte offsets for the start and end frames relative to the entire recording + # Calculate byte offsets for the start frames relative to the entire recording start_byte = byte_offset + start_frame * num_channels * dtype_size_bytes - end_byte = byte_offset + end_frame * num_channels * dtype_size_bytes - # The mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY - memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY) - memmap_offset *= mmap.ALLOCATIONGRANULARITY - - # This maps in bytes the region of the memmap that corresponds to the chunk - length = (end_byte - start_byte) + start_offset - memmap_obj = mmap.mmap(file.fileno(), length=length, access=mmap.ACCESS_WRITE, offset=memmap_offset) - - # To use numpy semantics we use the array interface of the memmap object - num_frames = end_frame - start_frame - shape = (num_frames, num_channels) - memmap_array = np.ndarray(shape=shape, dtype=dtype, buffer=memmap_obj, offset=start_offset) - - # Extract the traces and store them in the memmap array traces = recording.get_traces(start_frame=start_frame, end_frame=end_frame, segment_index=segment_index) - - if traces.dtype != dtype: - traces = traces.astype(dtype, copy=False) - - memmap_array[...] = traces - - memmap_obj.flush() - - memmap_obj.close() + traces = traces.astype(dtype, order="c", copy=False) + + file.seek(start_byte) + file.write(traces.data) + # flush is important!! + file.flush() write_binary_recording.__doc__ = write_binary_recording.__doc__.format(_shared_job_kwargs_doc) diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index fc01122269..cdd5897675 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -26,11 +26,14 @@ def test_BaseRecording(create_cache_folder): num_samples = 30 sampling_frequency = 10000 dtype = "int16" + seed = None + rng = np.random.default_rng(seed=seed) file_paths = [cache_folder / f"test_base_recording_{i}.raw" for i in range(num_seg)] for i in range(num_seg): a = np.memmap(file_paths[i], dtype=dtype, mode="w+", shape=(num_samples, num_chan)) - a[:] = np.random.randn(*a.shape).astype(dtype) + a[:] = rng.normal(scale=5000, size=a.shape).astype(dtype) + rec = BinaryRecordingExtractor( file_paths=file_paths, sampling_frequency=sampling_frequency, num_channels=num_chan, dtype=dtype ) @@ -201,6 +204,7 @@ def test_BaseRecording(create_cache_folder): positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) traces2 = rec2.get_traces(segment_index=0) + assert np.array_equal(traces2, rec_p.get_traces(segment_index=0)) # from probeinterface.plotting import plot_probe_group, plot_probe @@ -468,5 +472,8 @@ def test_time_slice_with_time_vector(): if __name__ == "__main__": - # test_BaseRecording() - test_interleaved_probegroups() + import tempfile + tmp_path = Path(tempfile.mkdtemp()) + + test_BaseRecording(tmp_path) + # test_interleaved_probegroups()