Skip to content
Merged
4 changes: 2 additions & 2 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
if kwargs.get("sharedmem", True):
from .numpyextractors import SharedMemoryRecording

cached = SharedMemoryRecording.from_recording(self, **job_kwargs)
cached = SharedMemoryRecording.from_recording(self, t_starts=t_starts, **job_kwargs)
else:
from spikeinterface.core import NumpyRecording

cached = NumpyRecording.from_recording(self, **job_kwargs)
cached = NumpyRecording.from_recording(self, t_starts=t_starts, **job_kwargs)

elif format == "zarr":
from .zarrextractors import ZarrRecordingExtractor
Expand Down
13 changes: 6 additions & 7 deletions src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
}

@staticmethod
def from_recording(source_recording, **job_kwargs):
def from_recording(source_recording, t_starts=None, **job_kwargs):
traces_list, shms = write_memory_recording(source_recording, dtype=None, **job_kwargs)
if shms[0] is not None:
# if the computation was done in parallel then traces_list is shared array
Expand All @@ -95,13 +95,14 @@ def from_recording(source_recording, **job_kwargs):
for shm in shms:
shm.close()
shm.unlink()
# TODO later : propagte t_starts ?

recording = NumpyRecording(
traces_list,
source_recording.get_sampling_frequency(),
t_starts=None,
t_starts=t_starts,
channel_ids=source_recording.channel_ids,
)
return recording


class NumpyRecordingSegment(BaseRecordingSegment):
Expand Down Expand Up @@ -211,18 +212,16 @@ def __del__(self):
shm.unlink()

@staticmethod
def from_recording(source_recording, **job_kwargs):
def from_recording(source_recording, t_starts=None, **job_kwargs):
traces_list, shms = write_memory_recording(source_recording, buffer_type="sharedmem", **job_kwargs)

# TODO later : propagte t_starts ?

recording = SharedMemoryRecording(
shm_names=[shm.name for shm in shms],
shape_list=[traces.shape for traces in traces_list],
dtype=source_recording.dtype,
sampling_frequency=source_recording.sampling_frequency,
channel_ids=source_recording.channel_ids,
t_starts=None,
t_starts=t_starts,
main_shm_owner=True,
)

Expand Down