Skip to content

Commit 4061fdf

Browse files
authored
Merge pull request #1603 from iurillilab/fixing-ttl-multichan
OpenEphysBinaryRawIO: Fixing ttl multichan
2 parents 2576037 + 44d4807 commit 4061fdf

File tree

2 files changed

+74
-20
lines changed

2 files changed

+74
-20
lines changed

neo/rawio/openephysbinaryrawio.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ def _parse_header(self):
217217
if name + "_npy" in info:
218218
data = np.load(info[name + "_npy"], mmap_mode="r")
219219
info[name] = data
220-
221220
# check that events have timestamps
222221
assert "timestamps" in info, "Event stream does not have timestamps!"
223222
# Updates for OpenEphys v0.6:
@@ -253,30 +252,64 @@ def _parse_header(self):
253252
# 'states' was introduced in OpenEphys v0.6. For previous versions, events used 'channel_states'
254253
if "states" in info or "channel_states" in info:
255254
states = info["channel_states"] if "channel_states" in info else info["states"]
255+
256256
if states.size > 0:
257257
timestamps = info["timestamps"]
258258
labels = info["labels"]
259-
rising = np.where(states > 0)[0]
260-
falling = np.where(states < 0)[0]
261259

262-
# infer durations
260+
# Identify unique channels based on state values
261+
channels = np.unique(np.abs(states))
262+
263+
rising_indices = []
264+
falling_indices = []
265+
266+
# all channels are packed into the same `states` array.
267+
# So the states array includes positive and negative values for each channel:
268+
# for example channel one rising would be +1 and channel one falling would be -1,
269+
# channel two rising would be +2 and channel two falling would be -2, etc.
270+
# This is the case for sure for version >= 0.6.x.
271+
for channel in channels:
272+
# Find rising and falling edges for each channel
273+
rising = np.where(states == channel)[0]
274+
falling = np.where(states == -channel)[0]
275+
276+
# Ensure each rising has a corresponding falling
277+
if rising.size > 0 and falling.size > 0:
278+
if rising[0] > falling[0]:
279+
falling = falling[1:]
280+
if rising.size > falling.size:
281+
rising = rising[:-1]
282+
283+
# ensure that the number of rising and falling edges are the same:
284+
if len(rising) != len(falling):
285+
warn(
286+
f"Channel {channel} has {len(rising)} rising edges and "
287+
f"{len(falling)} falling edges. The number of rising and "
288+
f"falling edges should be equal. Skipping events from this channel."
289+
)
290+
continue
291+
292+
rising_indices.extend(rising)
293+
falling_indices.extend(falling)
294+
295+
rising_indices = np.array(rising_indices)
296+
falling_indices = np.array(falling_indices)
297+
298+
# Sort the indices to maintain chronological order
299+
sorted_order = np.argsort(rising_indices)
300+
rising_indices = rising_indices[sorted_order]
301+
falling_indices = falling_indices[sorted_order]
302+
263303
durations = None
264-
if len(states) > 0:
265-
# make sure first event is rising and last is falling
266-
if states[0] < 0:
267-
falling = falling[1:]
268-
if states[-1] > 0:
269-
rising = rising[:-1]
270-
271-
if len(rising) == len(falling):
272-
durations = timestamps[falling] - timestamps[rising]
273-
if not self._use_direct_evt_timestamps:
274-
timestamps = timestamps / info["sample_rate"]
275-
durations = durations / info["sample_rate"]
276-
277-
info["rising"] = rising
278-
info["timestamps"] = timestamps[rising]
279-
info["labels"] = labels[rising]
304+
# if len(rising_indices) == len(falling_indices):
305+
durations = timestamps[falling_indices] - timestamps[rising_indices]
306+
if not self._use_direct_evt_timestamps:
307+
timestamps = timestamps / info["sample_rate"]
308+
durations = durations / info["sample_rate"]
309+
310+
info["rising"] = rising_indices
311+
info["timestamps"] = timestamps[rising_indices]
312+
info["labels"] = labels[rising_indices]
280313
info["durations"] = durations
281314

282315
# no spike read yet

neo/test/rawiotest/test_openephysbinaryrawio.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from neo.rawio.openephysbinaryrawio import OpenEphysBinaryRawIO
44
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
55

6+
import numpy as np
7+
68

79
class TestOpenEphysBinaryRawIO(BaseTestRawIO, unittest.TestCase):
810
rawioclass = OpenEphysBinaryRawIO
@@ -57,6 +59,25 @@ def test_missing_folders(self):
5759
)
5860
rawio.parse_header()
5961

62+
def test_multiple_ttl_events_parsing(self):
63+
rawio = OpenEphysBinaryRawIO(
64+
self.get_local_path("openephysbinary/v0.6.x_neuropixels_with_sync"), load_sync_channel=False
65+
)
66+
rawio.parse_header()
67+
rawio.header = rawio.header
68+
# Testing co
69+
# This is the TTL events from the NI Board channel
70+
ttl_events = rawio._evt_streams[0][0][1]
71+
assert "rising" in ttl_events.keys()
72+
assert "labels" in ttl_events.keys()
73+
assert "durations" in ttl_events.keys()
74+
assert "timestamps" in ttl_events.keys()
75+
76+
# Check that durations of different event streams are correctly parsed:
77+
assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "1"], 0.5, atol=0.001)
78+
assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "6"], 0.025, atol=0.001)
79+
assert np.allclose(ttl_events["durations"][ttl_events["labels"] == "7"], 0.016666, atol=0.001)
80+
6081

6182
if __name__ == "__main__":
6283
unittest.main()

0 commit comments

Comments
 (0)