Skip to content

Commit 60de975

Browse files
sprengerJuliaSprenger
authored andcommitted
[neuralynx] load only spikes/events within time range of segments and adjust tests
1 parent 7a76dda commit 60de975

File tree

2 files changed

+52
-25
lines changed

2 files changed

+52
-25
lines changed

neo/rawio/neuralynxrawio/neuralynxrawio.py

Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,13 @@ def _spike_count(self, block_index, seg_index, unit_index):
597597
data = self._spike_memmap[chan_uid]
598598
ts = data['timestamp']
599599

600-
ts0, ts1 = self._timestamp_limits[seg_index]
600+
ts0 = self.segment_t_start(block_index, seg_index)
601+
ts1 = self.segment_t_stop(block_index, seg_index)
602+
603+
# rescale to integer sampling of time
604+
ts0 = int((ts0 + self.global_t_start) * 1e6)
605+
ts1 = int((ts1 + self.global_t_start) * 1e6)
606+
601607

602608
# only count spikes inside the timestamp limits, inclusive, and for the specified unit
603609
keep = (ts >= ts0) & (ts <= ts1) & (unit_id == data['unit_id'])
@@ -612,11 +618,15 @@ def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_s
612618
data = self._spike_memmap[chan_uid]
613619
ts = data['timestamp']
614620

615-
ts0, ts1 = self._timestamp_limits[seg_index]
616-
if t_start is not None:
617-
ts0 = int((t_start + self.global_t_start) * 1e6)
618-
if t_start is not None:
619-
ts1 = int((t_stop + self.global_t_start) * 1e6)
621+
ts0, ts1 = t_start, t_stop
622+
if ts0 is None:
623+
ts0 = self.segment_t_start(block_index, seg_index)
624+
if ts1 is None:
625+
ts1 = self.segment_t_stop(block_index, seg_index)
626+
627+
# rescale to integer sampling of time
628+
ts0 = int((ts0 + self.global_t_start) * 1e6)
629+
ts1 = int((ts1 + self.global_t_start) * 1e6)
620630

621631
keep = (ts >= ts0) & (ts <= ts1) & (unit_id == data['unit_id'])
622632
timestamps = ts[keep]
@@ -628,17 +638,20 @@ def _rescale_spike_timestamp(self, spike_timestamps, dtype):
628638
spike_times -= self.global_t_start
629639
return spike_times
630640

631-
def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index,
632-
t_start, t_stop):
641+
def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop):
633642
chan_uid, unit_id = self.internal_unit_ids[unit_index]
634643
data = self._spike_memmap[chan_uid]
635644
ts = data['timestamp']
636645

637-
ts0, ts1 = self._timestamp_limits[seg_index]
638-
if t_start is not None:
639-
ts0 = int((t_start + self.global_t_start) * 1e6)
640-
if t_start is not None:
641-
ts1 = int((t_stop + self.global_t_start) * 1e6)
646+
ts0, ts1 = t_start, t_stop
647+
if ts0 is None:
648+
ts0 = self.segment_t_start(block_index, seg_index)
649+
if ts1 is None:
650+
ts1 = self.segment_t_stop(block_index, seg_index)
651+
652+
# rescale to integer sampling of time
653+
ts0 = int((ts0 + self.global_t_start) * 1e6)
654+
ts1 = int((ts1 + self.global_t_start) * 1e6)
642655

643656
keep = (ts >= ts0) & (ts <= ts1) & (unit_id == data['unit_id'])
644657

@@ -656,7 +669,14 @@ def _event_count(self, block_index, seg_index, event_channel_index):
656669
event_id, ttl_input = self.internal_event_ids[event_channel_index]
657670
chan_id = self.header['event_channels'][event_channel_index]['id']
658671
data = self._nev_memmap[chan_id]
659-
ts0, ts1 = self._timestamp_limits[seg_index]
672+
673+
ts0 = self.segment_t_start(block_index, seg_index)
674+
ts1 = self.segment_t_stop(block_index, seg_index)
675+
676+
# rescale to integer sampling of time
677+
ts0 = int((ts0 + self.global_t_start) * 1e6)
678+
ts1 = int((ts1 + self.global_t_start) * 1e6)
679+
660680
ts = data['timestamp']
661681
keep = (ts >= ts0) & (ts <= ts1) & (data['event_id'] == event_id) & \
662682
(data['ttl_input'] == ttl_input)
@@ -667,12 +687,16 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s
667687
event_id, ttl_input = self.internal_event_ids[event_channel_index]
668688
chan_id = self.header['event_channels'][event_channel_index]['id']
669689
data = self._nev_memmap[chan_id]
670-
ts0, ts1 = self._timestamp_limits[seg_index]
671690

672-
if t_start is not None:
673-
ts0 = int((t_start + self.global_t_start) * 1e6)
674-
if t_start is not None:
675-
ts1 = int((t_stop + self.global_t_start) * 1e6)
691+
ts0, ts1 = t_start, t_stop
692+
if ts0 is None:
693+
ts0 = self.segment_t_start(block_index, seg_index)
694+
if ts1 is None:
695+
ts1 = self.segment_t_stop(block_index, seg_index)
696+
697+
# rescale to integer sampling of time
698+
ts0 = int((ts0 + self.global_t_start) * 1e6)
699+
ts1 = int((ts1 + self.global_t_start) * 1e6)
676700

677701
ts = data['timestamp']
678702
keep = (ts >= ts0) & (ts <= ts1) & (data['event_id'] == event_id) & \

neo/test/iotest/test_neuralynxio.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,12 +331,15 @@ def test_incomplete_block_handling_v632(self):
331331
self.assertEqual(len(block.segments), n_gaps + 1)
332332
# self.assertEqual(len(block.channel_indexes[0].analogsignals), n_gaps + 1)
333333

334-
for t, gt in zip(nio._ncs_seg_timestamp_limits.t_start, [8408.806811, 8427.832053,
335-
8487.768561]):
336-
self.assertEqual(np.round(t, 4), np.round(gt, 4))
337-
for t, gt in zip(nio._ncs_seg_timestamp_limits.t_stop, [8427.831990, 8487.768498,
338-
8515.816549]):
339-
self.assertEqual(np.round(t, 4), np.round(gt, 4))
334+
expected_segment_starts = [8124.582909, 8427.832053, 8487.768561]
335+
expected_segment_stops = [8427.831990, 8487.768498, 10794.133994]
336+
for seg_idx in range(n_gaps+1):
337+
338+
t_start = nio.segment_t_start(0, seg_idx) + nio.global_t_start
339+
t_stop = nio.segment_t_stop(0, seg_idx) + nio.global_t_start
340+
341+
self.assertEqual(np.round(t_start, 4), np.round(expected_segment_starts[seg_idx], 4))
342+
self.assertEqual(np.round(t_stop, 4), np.round(expected_segment_stops[seg_idx], 4))
340343

341344

342345
class TestGaps(CommonNeuralynxIOTest, unittest.TestCase):

0 commit comments

Comments
 (0)