@@ -156,7 +156,6 @@ def _parse_header(self):
156156 assert np .unique (signal_channels ["sampling_rate" ]).size == 1
157157 self ._sampling_rate = float (np .unique (signal_channels ["sampling_rate" ])[0 ])
158158
159- # TODO change this when multi segment handling
160159 seg_limits = [trace_offset for seg_start , trace_offset in self .info_segments ] + [self ._raw_signals .shape [0 ]]
161160 nb_segment = len (self .info_segments )
162161 self ._t_starts = []
@@ -191,13 +190,17 @@ def _parse_header(self):
191190 dtype = np .dtype (ev_dtype )
192191 rawevent = np .memmap (self .filename , dtype = dtype , mode = "r" , offset = pos , shape = length // dtype .itemsize )
193192
194- keep = (
195- (rawevent ["start" ] >= rawevent ["start" ][0 ])
196- & (rawevent ["start" ] < self ._raw_signals .shape [0 ])
197- & (rawevent ["start" ] != 0 )
198- )
199- rawevent = rawevent [keep ]
200- self ._raw_events .append (rawevent )
193+ # important : all events timing are related to the first segment t_start
194+ self ._raw_events .append ([])
195+ for seg_index in range (nb_segment ):
196+ left_lim = seg_limits [seg_index ]
197+ right_lim = seg_limits [seg_index + 1 ]
198+ keep = (
199+ (rawevent ["start" ] >= left_lim )
200+ & (rawevent ["start" ] < right_lim )
201+ & (rawevent ["start" ] != 0 )
202+ )
203+ self ._raw_events [- 1 ].append (rawevent [keep ])
201204
202205 # No spikes
203206 spike_channels = []
@@ -254,22 +257,26 @@ def _spike_count(self, block_index, seg_index, unit_index):
254257 return 0
255258
256259 def _event_count (self , block_index , seg_index , event_channel_index ):
257- n = self ._raw_events [event_channel_index ].size
260+ n = self ._raw_events [event_channel_index ][ seg_index ] .size
258261 return n
259262
260263 def _get_event_timestamps (self , block_index , seg_index , event_channel_index , t_start , t_stop ):
261264
262- raw_event = self ._raw_events [event_channel_index ]
265+ raw_event = self ._raw_events [event_channel_index ][seg_index ]
266+
267+ # important : all events timing are related to the first segment t_start
268+ seg_start0 , _ = self .info_segments [0 ]
263269
264270 if t_start is not None :
265- keep = raw_event ["start" ] >= int (t_start * self ._sampling_rate )
271+ keep = raw_event ["start" ] + seg_start0 >= int (t_start * self ._sampling_rate )
266272 raw_event = raw_event [keep ]
267273
268274 if t_stop is not None :
269- keep = raw_event ["start" ] <= int (t_stop * self ._sampling_rate )
275+ keep = raw_event ["start" ] + seg_start0 <= int (t_stop * self ._sampling_rate )
270276 raw_event = raw_event [keep ]
271277
272- timestamp = raw_event ["start" ]
278+ timestamp = raw_event ["start" ] + seg_start0
279+
273280 if event_channel_index < 2 :
274281 durations = None
275282 else :
@@ -285,8 +292,7 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
285292
286293 def _rescale_event_timestamp (self , event_timestamps , dtype , event_channel_index ):
287294 event_times = event_timestamps .astype (dtype ) / self ._sampling_rate
288- # event_times += self._global_t_start
289- return event_times
295+ return event_times
290296
291297 def _rescale_epoch_duration (self , raw_duration , dtype , event_channel_index ):
292298 durations = raw_duration .astype (dtype ) / self ._sampling_rate
0 commit comments