Skip to content

Commit e803b85

Browse files
author
sprenger
committed
[edf] add support for events & epochs
1 parent ea6f182 commit e803b85

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

neo/rawio/edfrawio.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"""
1212

1313
from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
14-
_spike_channel_dtype, _event_channel_dtype)
14+
_spike_channel_dtype, _event_channel_dtype)
1515

1616
import numpy as np
1717

@@ -114,6 +114,8 @@ def _parse_header(self):
114114
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
115115

116116
event_channels = []
117+
event_channels.append(('Event', 'event_channel', 'event'))
118+
event_channels.append(('Epoch', 'epoch_channel', 'epoch'))
117119
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
118120

119121
self.header = {}
@@ -213,16 +215,38 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, spike_channel_index,
213215
return None
214216

215217
def _event_count(self, block_index, seg_index, event_channel_index):
216-
return None
218+
return len(self.edf_reader.readAnnotations()[0])
217219

218220
def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
219-
return None
221+
# these time should be already in seconds
222+
timestamps, durations, labels = self.edf_reader.readAnnotations()
223+
if t_start is None:
224+
t_start = self.segment_t_start(block_index, seg_index)
225+
if t_stop is None:
226+
t_stop = self.segment_t_stop(block_index, seg_index)
227+
228+
# only consider events and epochs that overlap with t_start t_stop range
229+
time_mask = ((t_start < timestamps) & (timestamps < t_stop)) | \
230+
((t_start < (timestamps+durations)) & ((timestamps+durations) < t_stop))
231+
232+
# separate event from epoch times
233+
event_mask = durations[time_mask] == 0
234+
if self.header['event_channels']['type'][event_channel_index] == b'epoch':
235+
event_mask = ~event_mask
236+
durations = durations[time_mask][event_mask]
237+
elif self.header['event_channels']['type'][event_channel_index] == b'event':
238+
durations = None
239+
240+
times = timestamps[time_mask][event_mask]
241+
labels = labels[time_mask][event_mask]
242+
243+
return times, durations, labels
220244

221245
def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index):
222-
return None
246+
return np.asarray(event_timestamps, dtype=dtype)
223247

224248
def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
225-
return None
249+
return np.asarray(raw_duration, dtype=dtype)
226250

227251
def __enter__(self):
228252
return self

0 commit comments

Comments
 (0)