Skip to content

Commit 2106198

Browse files
committed
spikeglx sync PR
1 parent b3dd09d commit 2106198

File tree

2 files changed

+71
-8
lines changed

2 files changed

+71
-8
lines changed

neo/rawio/spikeglxrawio.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from pathlib import Path
5454
import os
5555
import re
56+
from warnings import warn
5657

5758
import numpy as np
5859

@@ -109,6 +110,12 @@ def __init__(self, dirname="", load_sync_channel=False, load_channel_location=Fa
109110
BaseRawWithBufferApiIO.__init__(self)
110111
self.dirname = dirname
111112
self.load_sync_channel = load_sync_channel
113+
if load_sync_channel:
114+
warn(
115+
"The load_sync_channel=True option is deprecated and will be removed in version 0.15. "
116+
"Use load_sync_channel=False instead, which will add sync channels as separate streams.",
117+
DeprecationWarning, stacklevel=2
118+
)
112119
self.load_channel_location = load_channel_location
113120

114121
def _source_name(self):
@@ -152,6 +159,8 @@ def _parse_header(self):
152159
signal_buffers = []
153160
signal_streams = []
154161
signal_channels = []
162+
sync_stream_id_to_buffer_id = {}
163+
155164
for stream_name in stream_names:
156165
# take first segment
157166
info = self.signals_info_dict[0, stream_name]
@@ -168,6 +177,16 @@ def _parse_header(self):
168177
for local_chan in range(info["num_chan"]):
169178
chan_name = info["channel_names"][local_chan]
170179
chan_id = f"{stream_name}#{chan_name}"
180+
181+
# Sync channel
182+
if "nidq" not in stream_name and "SY0" in chan_name and not self.load_sync_channel and local_chan == info["num_chan"] - 1:
183+
# This is a sync channel and should be added as its own stream
184+
sync_stream_id = f"{stream_name}-SYNC"
185+
sync_stream_id_to_buffer_id[sync_stream_id] = buffer_id
186+
stream_id_for_chan = sync_stream_id
187+
else:
188+
stream_id_for_chan = stream_id
189+
171190
signal_channels.append(
172191
(
173192
chan_name,
@@ -177,25 +196,33 @@ def _parse_header(self):
177196
info["units"],
178197
info["channel_gains"][local_chan],
179198
info["channel_offsets"][local_chan],
180-
stream_id,
199+
stream_id_for_chan,
181200
buffer_id,
182201
)
183202
)
184203

185-
# all channel by dafult unless load_sync_channel=False
204+
# all channel by default unless load_sync_channel=False
186205
self._stream_buffer_slice[stream_id] = None
206+
187207
# check sync channel validity
188208
if "nidq" not in stream_name:
189209
if not self.load_sync_channel and info["has_sync_trace"]:
190-
# the last channel is remove from the stream but not from the buffer
191-
last_chan = signal_channels[-1]
192-
last_chan = last_chan[:-2] + ("", buffer_id)
193-
signal_channels = signal_channels[:-1] + [last_chan]
210+
# the last channel is removed from the stream but not from the buffer
194211
self._stream_buffer_slice[stream_id] = slice(0, -1)
212+
213+
# Add a buffer slice for the sync channel
214+
sync_stream_id = f"{stream_name}-SYNC"
215+
self._stream_buffer_slice[sync_stream_id] = slice(-1, None)
216+
195217
if self.load_sync_channel and not info["has_sync_trace"]:
196218
raise ValueError("SYNC channel is not present in the recording. " "Set load_sync_channel to False")
197219

198220
signal_buffers = np.array(signal_buffers, dtype=_signal_buffer_dtype)
221+
222+
# Add sync channels as their own streams
223+
for sync_stream_id, buffer_id in sync_stream_id_to_buffer_id.items():
224+
signal_streams.append((sync_stream_id, sync_stream_id, buffer_id))
225+
199226
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
200227
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
201228

@@ -237,6 +264,14 @@ def _parse_header(self):
237264
t_start = frame_start / sampling_frequency
238265

239266
self._t_starts[stream_name][seg_index] = t_start
267+
268+
# This need special logic because sync not present in stream_names
269+
if f"{stream_name}-SYNC" in signal_streams["name"]:
270+
sync_stream_name = f"{stream_name}-SYNC"
271+
if sync_stream_name not in self._t_starts:
272+
self._t_starts[sync_stream_name] = {}
273+
self._t_starts[sync_stream_name][seg_index] = t_start
274+
240275
t_stop = info["sample_length"] / info["sampling_rate"]
241276
self._t_stops[seg_index] = max(self._t_stops[seg_index], t_stop)
242277

@@ -265,7 +300,11 @@ def _parse_header(self):
265300
if self.load_channel_location:
266301
# need probeinterface to be installed
267302
import probeinterface
268-
303+
304+
# Skip for sync streams
305+
if "SYNC" in stream_name:
306+
continue
307+
269308
info = self.signals_info_dict[seg_index, stream_name]
270309
if "imroTbl" in info["meta"] and info["stream_kind"] == "ap":
271310
# only for ap channel

neo/test/rawiotest/test_spikeglxrawio.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_loading_only_one_probe_in_multi_probe_scenario(self):
5555
rawio = SpikeGLXRawIO(probe_folder_path)
5656
rawio.parse_header()
5757

58-
expected_stream_names = ["imec1.ap", "imec1.lf"]
58+
expected_stream_names = ["imec1.ap", "imec1.lf", "imec1.ap-SYNC", "imec1.lf-SYNC"]
5959
actual_stream_names = rawio.header["signal_streams"]["name"].tolist()
6060
assert (
6161
actual_stream_names == expected_stream_names
@@ -130,6 +130,30 @@ def test_nidq_digital_channel(self):
130130
atol = 0.001
131131
assert np.allclose(on_diff, 1, atol=atol)
132132

133+
def test_sync_channel_as_separate_stream(self):
134+
"""Test that sync channel is added as its own stream when load_sync_channel=False."""
135+
import warnings
136+
137+
# Test with load_sync_channel=False (default)
138+
rawio_no_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=False)
139+
rawio_no_sync.parse_header()
140+
141+
# Get stream names
142+
stream_names = rawio_no_sync.header["signal_streams"]["name"].tolist()
143+
144+
# Check if there's a sync channel stream (should contain "SY0" or "SYNC" in the name)
145+
sync_streams = [name for name in stream_names if "SY0" in name or "SYNC" in name]
146+
assert len(sync_streams) > 0, "No sync channel stream found when load_sync_channel=False"
147+
148+
# Test deprecation warning when load_sync_channel=True
149+
with warnings.catch_warnings(record=True) as w:
150+
warnings.simplefilter("always")
151+
rawio_with_sync = SpikeGLXRawIO(self.get_local_path("spikeglx/NP2_with_sync"), load_sync_channel=True)
152+
153+
# Check if deprecation warning was raised
154+
assert any(issubclass(warning.category, DeprecationWarning) for warning in w), "No deprecation warning raised"
155+
assert any("will be removed in version 0.15" in str(warning.message) for warning in w), "Deprecation warning message is incorrect"
156+
133157
def test_t_start_reading(self):
134158
"""Test that t_start values are correctly read for all streams and segments."""
135159

0 commit comments

Comments
 (0)