Skip to content

Commit 4a6bb20

Browse files
authored
Merge pull request #1383 from manimoh/extract_events_from_SpikeGLX
Extract events from spike glx
2 parents 0401de0 + 15a199f commit 4a6bb20

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

neo/rawio/spikeglxrawio.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,22 @@ def _parse_header(self):
177177

178178
# No events
179179
event_channels = []
180+
# This is true only in case of 'nidq' stream
181+
for stream_name in stream_names:
182+
if "nidq" in stream_name:
183+
info = self.signals_info_dict[0, stream_name]
184+
if len(info["digital_channels"]) > 0:
185+
# add event channels
186+
for local_chan in info["digital_channels"]:
187+
chan_name = local_chan
188+
chan_id = f"{stream_name}#{chan_name}"
189+
event_channels.append((chan_name, chan_id, "event"))
190+
# add events_memmap
191+
data = np.memmap(info["bin_file"], dtype="int16", mode="r", offset=0, order="C")
192+
data = data.reshape(-1, info["num_chan"])
193+
# The digital word is usually the last channel, after all the individual analog channels
194+
extracted_word = data[:, len(info["analog_channels"])]
195+
self._events_memmap = np.unpackbits(extracted_word.astype(np.uint8)[:, None], axis=1)
180196
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
181197

182198
# No spikes
@@ -272,6 +288,44 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
272288

273289
return raw_signals
274290

291+
def _event_count(self, event_channel_idx, block_index=None, seg_index=None):
292+
timestamps, _, _ = self._get_event_timestamps(block_index, seg_index, event_channel_idx, None, None)
293+
return timestamps.size
294+
295+
def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start=None, t_stop=None):
296+
timestamps, durations, labels = [], None, []
297+
info = self.signals_info_dict[0, "nidq"] # There are no events that are not in the nidq stream
298+
dig_ch = info["digital_channels"]
299+
if len(dig_ch) > 0:
300+
event_data = self._events_memmap
301+
channel = dig_ch[event_channel_index]
302+
ch_idx = 7 - int(channel[2:]) # They are in the reverse order
303+
this_stream = event_data[:, ch_idx]
304+
this_rising = np.where(np.diff(this_stream) == 1)[0] + 1
305+
this_falling = (
306+
np.where(np.diff(this_stream) == 255)[0] + 1
307+
) # because the data is in unsigned 8 bit, -1 = 255!
308+
if len(this_rising) > 0:
309+
timestamps.extend(this_rising)
310+
labels.extend([f"{channel} ON"] * len(this_rising))
311+
if len(this_falling) > 0:
312+
timestamps.extend(this_falling)
313+
labels.extend([f"{channel} OFF"] * len(this_falling))
314+
timestamps = np.asarray(timestamps)
315+
if len(labels) == 0:
316+
labels = np.asarray(labels, dtype="U1")
317+
else:
318+
labels = np.asarray(labels)
319+
return timestamps, durations, labels
320+
321+
def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index):
322+
info = self.signals_info_dict[0, "nidq"] # There are no events that are not in the nidq stream
323+
event_times = event_timestamps.astype(dtype) / float(info["sampling_rate"])
324+
return event_times
325+
326+
def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
327+
return None
328+
275329

276330
def scan_files(dirname):
277331
"""
@@ -507,4 +561,14 @@ def extract_stream_info(meta_file, meta):
507561
info["channel_offsets"] = np.zeros(info["num_chan"])
508562
info["has_sync_trace"] = has_sync_trace
509563

564+
if "nidq" in device:
565+
info["digital_channels"] = []
566+
info["analog_channels"] = [channel for channel in info["channel_names"] if not channel.startswith("XD")]
567+
# Digital/event channels are encoded within the digital word, so that will need more handling
568+
for item in meta["niXDChans1"].split(","):
569+
if ":" in item:
570+
start, end = map(int, item.split(":"))
571+
info["digital_channels"].extend([f"XD{i}" for i in range(start, end + 1)])
572+
else:
573+
info["digital_channels"].append(f"XD{int(item)}")
510574
return info

neo/test/rawiotest/test_spikeglxrawio.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from neo.rawio.spikeglxrawio import SpikeGLXRawIO
88
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
9+
import numpy as np
910

1011

1112
class TestSpikeGLXRawIO(BaseTestRawIO, unittest.TestCase):
@@ -87,6 +88,21 @@ def test_subset_with_sync(self):
8788
)
8889
assert chunk.shape[1] == 120
8990

91+
def test_nidq_digital_channel(self):
92+
rawio_digital = SpikeGLXRawIO(self.get_local_path("spikeglx/DigitalChannelTest_g0"))
93+
rawio_digital.parse_header()
94+
# This data should have 8 event channels
95+
assert np.shape(rawio_digital.header["event_channels"])[0] == 8
96+
97+
# Channel 0 in this data will have sync pulses at 1 Hz, let's confirm that
98+
all_events = rawio_digital.get_event_timestamps(0, 0, 0)
99+
on_events = np.where(all_events[2] == "XD0 ON")
100+
on_ts = all_events[0][on_events]
101+
on_ts_scaled = rawio_digital.rescale_event_timestamp(on_ts)
102+
on_diff = np.diff(on_ts_scaled)
103+
atol = 0.001
104+
assert np.allclose(on_diff, 1, atol=atol)
105+
90106

91107
if __name__ == "__main__":
92108
unittest.main()

0 commit comments

Comments
 (0)