Skip to content

Commit 8b48f4b

Browse files
committed
Handle events offset when multi segment
1 parent 5342c41 commit 8b48f4b

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

neo/rawio/micromedrawio.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ def _parse_header(self):
156156
assert np.unique(signal_channels["sampling_rate"]).size == 1
157157
self._sampling_rate = float(np.unique(signal_channels["sampling_rate"])[0])
158158

159-
# TODO change this when multi segment handling
160159
seg_limits = [trace_offset for seg_start, trace_offset in self.info_segments] + [self._raw_signals.shape[0]]
161160
nb_segment = len(self.info_segments)
162161
self._t_starts = []
@@ -191,13 +190,17 @@ def _parse_header(self):
191190
dtype = np.dtype(ev_dtype)
192191
rawevent = np.memmap(self.filename, dtype=dtype, mode="r", offset=pos, shape=length // dtype.itemsize)
193192

194-
keep = (
195-
(rawevent["start"] >= rawevent["start"][0])
196-
& (rawevent["start"] < self._raw_signals.shape[0])
197-
& (rawevent["start"] != 0)
198-
)
199-
rawevent = rawevent[keep]
200-
self._raw_events.append(rawevent)
193+
# important : all events timing are related to the first segment t_start
194+
self._raw_events.append([])
195+
for seg_index in range(nb_segment):
196+
left_lim = seg_limits[seg_index]
197+
right_lim = seg_limits[seg_index + 1]
198+
keep = (
199+
(rawevent["start"] >= left_lim)
200+
& (rawevent["start"] < right_lim)
201+
& (rawevent["start"] != 0)
202+
)
203+
self._raw_events[-1].append(rawevent[keep])
201204

202205
# No spikes
203206
spike_channels = []
@@ -254,22 +257,26 @@ def _spike_count(self, block_index, seg_index, unit_index):
254257
return 0
255258

256259
def _event_count(self, block_index, seg_index, event_channel_index):
257-
n = self._raw_events[event_channel_index].size
260+
n = self._raw_events[event_channel_index][seg_index].size
258261
return n
259262

260263
def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
261264

262-
raw_event = self._raw_events[event_channel_index]
265+
raw_event = self._raw_events[event_channel_index][seg_index]
266+
267+
# important : all events timing are related to the first segment t_start
268+
seg_start0, _ = self.info_segments[0]
263269

264270
if t_start is not None:
265-
keep = raw_event["start"] >= int(t_start * self._sampling_rate)
271+
keep = raw_event["start"] + seg_start0 >= int(t_start * self._sampling_rate)
266272
raw_event = raw_event[keep]
267273

268274
if t_stop is not None:
269-
keep = raw_event["start"] <= int(t_stop * self._sampling_rate)
275+
keep = raw_event["start"] + seg_start0 <= int(t_stop * self._sampling_rate)
270276
raw_event = raw_event[keep]
271277

272-
timestamp = raw_event["start"]
278+
timestamp = raw_event["start"] + seg_start0
279+
273280
if event_channel_index < 2:
274281
durations = None
275282
else:
@@ -285,8 +292,7 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
285292

286293
def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index):
287294
event_times = event_timestamps.astype(dtype) / self._sampling_rate
288-
# event_times += self._global_t_start
289-
return event_times
295+
return event_times
290296

291297
def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index):
292298
durations = raw_duration.astype(dtype) / self._sampling_rate

0 commit comments

Comments
 (0)