Skip to content

Commit 18e502e

Browse files
sprenger (Julia Sprenger)
authored and committed
[neuralynx] simplify segment time limit handling and streams
- adjust tests
1 parent 421b5ed commit 18e502e

File tree

2 files changed

+72
-75
lines changed

2 files changed

+72
-75
lines changed

neo/rawio/neuralynxrawio/neuralynxrawio.py

Lines changed: 63 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,13 @@
3434
event the gaps are larger, this RawIO only provides the samples from the first section as belonging
3535
to one Segment.
3636
37-
This RawIO presents only a single Block and Segment.
38-
:TODO: This should likely be changed to provide multiple segments and allow for
39-
multiple .Ncs files in a directory with differing section structures.
37+
This RawIO presents only a single Block.
4038
4139
Author: Julia Sprenger, Carlos Canova, Samuel Garcia, Peter N. Steinmetz.
4240
"""
4341

44-
4542
from ..baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
46-
_spike_channel_dtype, _event_channel_dtype)
43+
_spike_channel_dtype, _event_channel_dtype)
4744
from operator import itemgetter
4845
import numpy as np
4946
import os
@@ -89,8 +86,8 @@ def __init__(self, dirname='', filename='', exclude_filename=None, keep_original
8986
name of directory containing all files for dataset. If provided, filename is
9087
ignored.
9188
filename: str
92-
name of a single ncs, nse, nev, or ntt file to include in dataset. If used,
93-
dirname must not be provided.
89+
name of a single ncs, nse, nev, or ntt file to include in dataset. Will be ignored,
90+
if dirname is provided.
9491
exclude_filename: str or list
9592
name of a single ncs, nse, nev or ntt file or list of such files. Expects plain
9693
filenames (without directory path).
@@ -163,8 +160,7 @@ def _parse_header(self):
163160
if excl_file in filenames:
164161
filenames.remove(excl_file)
165162

166-
stream_props = {} # {(sampling_rate, n_samples, t_start):
167-
# {stream_id: [filenames]}
163+
stream_props = {} # {(sampling_rate, n_samples, t_start): {stream_id: [filenames]}
168164

169165
for filename in filenames:
170166
filename = os.path.join(dirname, filename)
@@ -213,7 +209,7 @@ def _parse_header(self):
213209
offset = 0.
214210
stream_id = stream_id
215211
signal_channels.append((chan_name, str(chan_id), info['sampling_rate'],
216-
'int16', units, gain, offset, stream_id))
212+
'int16', units, gain, offset, stream_id))
217213
self.ncs_filenames[chan_uid] = filename
218214
keys = [
219215
'DspFilterDelay_µs',
@@ -296,8 +292,7 @@ def _parse_header(self):
296292
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
297293

298294
if signal_channels.size > 0:
299-
names = [f'Stream with (sampling_rate, n_packets, t_start): ' \
300-
f'({stream_prop})' for stream_prop in stream_props]
295+
names = [f'Stream (rate,#packet,t0): {sp}' for sp in stream_props]
301296
ids = [stream_prop['stream_id'] for stream_prop in stream_props.values()]
302297
signal_streams = list(zip(names, ids))
303298
else:
@@ -315,36 +310,52 @@ def _parse_header(self):
315310
for stream_id in np.unique(signal_channels['stream_id']):
316311
stream_channels = signal_channels[signal_channels['stream_id'] == stream_id]
317312
stream_chan_uids = zip(stream_channels['name'], stream_channels['id'])
318-
stream_filenames = itemgetter(*stream_chan_uids)(self.ncs_filenames)
319-
_sigs_memmaps, ncsSegTimestampLimits, section_structure = self.scan_stream_ncs_files(stream_filenames)
313+
stream_filenames = [self.ncs_filenames[chuid] for chuid in stream_chan_uids]
314+
_sigs_memmaps, ncsSegTimestampLimits, section_structure = self.scan_stream_ncs_files(
315+
stream_filenames)
320316

321317
stream_infos[stream_id] = {'segment_sig_memmaps': _sigs_memmaps,
322318
'ncs_segment_infos': ncsSegTimestampLimits,
323319
'section_structure': section_structure}
324320

325321
# check if section structure across streams is compatible and merge infos
326-
ref_stream_id = list(stream_infos.keys())[0]
327-
ref_sec_structure = stream_infos[ref_stream_id]['section_structure']
322+
ref_sec_structure = None
328323
for stream_id, stream_info in stream_infos.items():
324+
ref_stream_id = list(stream_infos.keys())[0]
325+
ref_sec_structure = stream_infos[ref_stream_id]['section_structure']
326+
329327
sec_structure = stream_info['section_structure']
330328

331329
# check if section structure of streams are compatible
332330
# using tolerance of one data packet (512 samples)
333331
tolerance = 512 / min(ref_sec_structure.sampFreqUsed,
334332
sec_structure.sampFreqUsed) * 1e6
335333
if not ref_sec_structure.is_equivalent(sec_structure, abs_tol=tolerance):
336-
ref_chan_ids = signal_channels[signal_channels['stream_id'] == ref_stream_id]['name']
334+
ref_chan_ids = signal_channels[signal_channels['stream_id'] == ref_stream_id][
335+
'name']
337336
chan_ids = signal_channels[signal_channels['stream_id'] == stream_id]['name']
338337

339338
raise ValueError('Incompatible section structures across streams: '
340339
f'Stream id {ref_stream_id}:{ref_chan_ids} and '
341340
f'{stream_id}:{chan_ids}.')
342341

343-
self._nb_segment = len(ref_sec_structure.sects)
342+
if ref_sec_structure is not None:
343+
self._nb_segment = len(ref_sec_structure.sects)
344+
else:
345+
# Use only a single segment if no ncs data is present
346+
self._nb_segment = 1
347+
348+
def min_max_tuple(tuple1, tuple2):
349+
"""Merge tuple by selecting min for first and max for 2nd entry"""
350+
mins, maxs = zip(tuple1, tuple2)
351+
result = (min(m for m in mins if m is not None), max(m for m in maxs if m is not None))
352+
return result
344353

345354
# merge stream mmemmaps since streams are compatible
346355
self._sigs_memmaps = [{} for seg_idx in range(self._nb_segment)]
356+
# time limits of integer timestamps in ncs files
347357
self._timestamp_limits = [(None, None) for seg_idx in range(self._nb_segment)]
358+
# time limits physical times in ncs files
348359
self._signal_limits = [(None, None) for seg_idx in range(self._nb_segment)]
349360
for stream_id, stream_info in stream_infos.items():
350361
stream_mmaps = stream_info['segment_sig_memmaps']
@@ -353,37 +364,14 @@ def _parse_header(self):
353364

354365
ncs_segment_info = stream_info['ncs_segment_infos']
355366
for seg_idx, (t_start, t_stop) in enumerate(ncs_segment_info.timestamp_limits):
356-
old_times = self._timestamp_limits[seg_idx]
357-
if (old_times[0] is None) or (t_start < old_times[0]):
358-
self._timestamp_limits[seg_idx] = (t_start, self._signal_limits[seg_idx][1])
359-
if (self._timestamp_limits[seg_idx][1] is None) or (
360-
t_stop > self._timestamp_limits[seg_idx][1]):
361-
self._timestamp_limits[seg_idx] = (self._signal_limits[seg_idx][0],
362-
t_stop)
367+
self._timestamp_limits[seg_idx] = min_max_tuple(self._timestamp_limits[seg_idx],
368+
(t_start, t_stop))
363369

364370
for seg_idx in range(ncs_segment_info.nb_segment):
365371
t_start = ncs_segment_info.t_start[seg_idx]
366372
t_stop = ncs_segment_info.t_stop[seg_idx]
367-
old_times = self._signal_limits[seg_idx]
368-
if (self._signal_limits[seg_idx][0] is None) or (
369-
t_start < self._signal_limits[seg_idx][0]):
370-
self._signal_limits[seg_idx] = (t_start, self._signal_limits[seg_idx][1])
371-
if (self._signal_limits[seg_idx][1] is None) or (
372-
t_stop > self._signal_limits[seg_idx][1]):
373-
self._signal_limits[seg_idx] = (self._signal_limits[seg_idx][0],
374-
t_stop)
375-
376-
# self._sigs_length = [{} for seg_idx in range(self._nb_segment)]
377-
# for stream_id, stream_info in stream_infos.items():
378-
# ncs_segment_info = stream_info['ncs_segment_infos']
379-
# chan_ids = signal_channels[signal_channels['stream_id'] == stream_id]['name']
380-
#
381-
# for chan_uid in chan_ids:
382-
#
383-
#
384-
# for seg_idx in range(self._nb_segment):
385-
386-
373+
self._signal_limits[seg_idx] = min_max_tuple(self._signal_limits[seg_idx],
374+
(t_start, t_stop))
387375

388376
# precompute signal lengths within segments
389377
self._sigs_length = []
@@ -393,7 +381,7 @@ def _parse_header(self):
393381
for chan_uid, sig_infos in sig_container.items():
394382
self._sigs_length[seg_idx][chan_uid] = int(sig_infos['nb_valid'].sum())
395383

396-
# Determine timestamp limits in nev, nse file by scanning them.
384+
# Determine timestamp limits in nev, nse, ntt files by scanning them.
397385
ts0, ts1 = None, None
398386
for _data_memmap in (self._spike_memmap, self._nev_memmap):
399387
for _, data in _data_memmap.items():
@@ -406,28 +394,36 @@ def _parse_header(self):
406394
ts0 = min(ts0, ts[0])
407395
ts1 = max(ts1, ts[-1])
408396

397+
# rescaling for comparison with signal times
398+
if ts0 is not None:
399+
timestamps_start, timestamps_stop = ts0 / 1e6, ts1 / 1e6
400+
409401
# decide on segment and global start and stop times based on files available
410402
if self._timestamp_limits is None:
411-
# case NO ncs but HAVE nev or nse
403+
# case NO ncs but HAVE nev or nse -> single segment covering all spikes & events
412404
self._timestamp_limits = [(ts0, ts1)]
413-
self._seg_t_starts = [ts0 / 1e6]
414-
self._seg_t_stops = [ts1 / 1e6]
415-
self.global_t_start = ts0 / 1e6
416-
self.global_t_stop = ts1 / 1e6
405+
self._seg_t_starts = [timestamps_start]
406+
self._seg_t_stops = [timestamps_stop]
407+
self.global_t_start = timestamps_start
408+
self.global_t_stop = timestamps_stop
417409
elif ts0 is not None:
418-
# case HAVE ncs AND HAVE nev or nse
419-
self.global_t_start = min(ts0, self._timestamp_limits[0][0]) /1e6
420-
self.global_t_stop = max(ts1 / 1e6, self._timestamp_limits[-1][-1])
421-
self._seg_t_starts = [limits[0] /1e6 for limits in self._timestamp_limits]
410+
# case HAVE ncs AND HAVE nev or nse -> multi segments based on ncs segmentation
411+
# ignoring nev/nse/ntt time limits, loading only data within ncs segments
412+
global_events_limits = (timestamps_start, timestamps_stop)
413+
global_signal_limits = (self._signal_limits[0][0], self._signal_limits[-1][-1])
414+
self.global_t_start, self.global_t_stop = min_max_tuple(global_events_limits,
415+
global_signal_limits)
416+
self._seg_t_starts = [limits[0] for limits in self._signal_limits]
417+
self._seg_t_stops = [limits[1] for limits in self._signal_limits]
422418
self._seg_t_starts[0] = self.global_t_start
423-
self._seg_t_stops = [limits[1] / 1e6 for limits in self._timestamp_limits]
424419
self._seg_t_stops[-1] = self.global_t_stop
420+
425421
else:
426-
# case HAVE ncs but NO nev or nse
427-
self._seg_t_starts = [limits[0] / 1e6 for limits in self._timestamp_limits]
428-
self._seg_t_stops = [limits[1] / 1e6 for limits in self._timestamp_limits]
429-
self.global_t_start = self._signal_limits[0][0] / 1e6
430-
self.global_t_stop = self._signal_limits[-1][-1] / 1e6
422+
# case HAVE ncs but NO nev or nse ->
423+
self._seg_t_starts = [limits[0] for limits in self._signal_limits]
424+
self._seg_t_stops = [limits[1] for limits in self._signal_limits]
425+
self.global_t_start = self._signal_limits[0][0]
426+
self.global_t_stop = self._signal_limits[-1][-1]
431427

432428
if self.keep_original_times:
433429
self.global_t_stop = self.global_t_stop - self.global_t_start
@@ -491,7 +487,7 @@ def _get_file_map(self, filename):
491487

492488
if suffix == 'ncs':
493489
return np.memmap(filename, dtype=self._ncs_dtype, mode='r',
494-
offset=NlxHeader.HEADER_SIZE)
490+
offset=NlxHeader.HEADER_SIZE)
495491

496492
elif suffix in ['nse', 'ntt']:
497493
info = NlxHeader(filename)
@@ -599,6 +595,9 @@ def _spike_count(self, block_index, seg_index, unit_index):
599595
return nb_spike
600596

601597
def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
598+
"""
599+
Extract timestamps within a Segment defined by ncs timestamps
600+
"""
602601
chan_uid, unit_id = self.internal_unit_ids[unit_index]
603602
data = self._spike_memmap[chan_uid]
604603
ts = data['timestamp']
@@ -718,7 +717,8 @@ def scan_stream_ncs_files(self, ncs_filenames):
718717
nlxHeader = NlxHeader(ncs_filename)
719718

720719
if not chanSectMap or (chanSectMap and
721-
not NcsSectionsFactory._verifySectionsStructure(data, chan_ncs_sections)):
720+
not NcsSectionsFactory._verifySectionsStructure(data,
721+
chan_ncs_sections)):
722722
chan_ncs_sections = NcsSectionsFactory.build_for_ncs_file(data, nlxHeader)
723723

724724
# register file section structure for all contained channels
@@ -784,7 +784,6 @@ def scan_stream_ncs_files(self, ncs_filenames):
784784
SegmentTimeLimits = namedtuple("SegmentTimeLimits", ['nb_segment', 't_start', 't_stop', 'length',
785785
'timestamp_limits'])
786786

787-
788787
nev_dtype = [
789788
('reserved', '<i2'),
790789
('system_id', '<i2'),

neo/test/rawiotest/test_neuralynxrawio.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ def test_scan_ncs_files(self):
4040
self.assertEqual(rawio._nb_segment, 1)
4141
self.assertListEqual(rawio._timestamp_limits, [(0, 192000)])
4242
self.assertEqual(rawio._sigs_length[0][('unknown', '1')], 4608)
43-
self.assertEqual(rawio._sigs_t_start[0], 0)
44-
self.assertEqual(rawio._sigs_t_stop[0], 0.192)
43+
self.assertListEqual(rawio._signal_limits, [(0, 0.192)])
4544
self.assertEqual(len(rawio._sigs_memmaps), 1)
4645

4746
# Test Cheetah 4.0.2, which is PRE4 type with frequency in header and
@@ -54,8 +53,7 @@ def test_scan_ncs_files(self):
5453
self.assertEqual(rawio.signal_streams_count(), 1)
5554
self.assertListEqual(rawio._timestamp_limits, [(266982936, 267162136)])
5655
self.assertEqual(rawio._sigs_length[0][('unknown', '13')], 5120)
57-
self.assertEqual(rawio._sigs_t_start[0], 266.982936)
58-
self.assertEqual(rawio._sigs_t_stop[0], 267.162136)
56+
self.assertListEqual(rawio._signal_limits, [(266.982936,267.162136)])
5957
self.assertEqual(len(rawio._sigs_memmaps), 1)
6058

6159
# Test Cheetah 5.5.1, which is DigitalLynxSX and has two blocks of records
@@ -71,8 +69,8 @@ def test_scan_ncs_files(self):
7169
{('Tet3a', '8'): 1278976, ('Tet3b', '9'): 1278976})
7270
self.assertDictEqual(rawio._sigs_length[1],
7371
{('Tet3a', '8'): 427008, ('Tet3b', '9'): 427008})
74-
self.assertListEqual(rawio._sigs_t_stop, [26162.525633, 26379.704633])
75-
self.assertListEqual(rawio._sigs_t_start, [26122.557633, 26366.360633])
72+
self.assertListEqual(rawio._signal_limits, [(26122.557633, 26162.525633),
73+
(26366.360633, 26379.704633)])
7674
self.assertEqual(len(rawio._sigs_memmaps), 2) # check only that there are 2 memmaps
7775

7876
# Test Cheetah 6.3.2, the incomplete_blocks test. This is a DigitalLynxSX with
@@ -89,8 +87,9 @@ def test_scan_ncs_files(self):
8987
self.assertDictEqual(rawio._sigs_length[0], {('CSC1', '48'): 608806})
9088
self.assertDictEqual(rawio._sigs_length[1], {('CSC1', '48'): 1917967})
9189
self.assertDictEqual(rawio._sigs_length[2], {('CSC1', '48'): 897536})
92-
self.assertListEqual(rawio._sigs_t_stop, [8427.831990, 8487.768498, 8515.816549])
93-
self.assertListEqual(rawio._sigs_t_start, [8408.806811, 8427.832053, 8487.768561])
90+
self.assertListEqual(rawio._signal_limits, [(8408.806811, 8427.831990),
91+
(8427.832053, 8487.768498),
92+
(8487.768561, 8515.816549)])
9493
self.assertEqual(len(rawio._sigs_memmaps), 3) # check that there are only 3 memmaps
9594

9695
# Test Cheetah 6.4.1, with different sampling rates across ncs files.
@@ -101,12 +100,11 @@ def test_scan_ncs_files(self):
101100
seg_idx = 0
102101

103102
self.assertEqual(rawio.signal_streams_count(), 3)
104-
self.assertListEqual(rawio._timestamp_limits, [(1614363777985169, 1614363778481169)])
103+
self.assertListEqual(rawio._timestamp_limits, [(1614363777825263, 1614363778481169)])
105104
self.assertDictEqual(rawio._sigs_length[seg_idx], {('CSC1', '26'): 15872,
106105
('LFP4', '41'): 1024,
107106
('WE1', '33'): 512})
108-
self.assertListEqual(rawio._sigs_t_stop, [1614363778.481169])
109-
self.assertListEqual(rawio._sigs_t_start, [1614363777.985169])
107+
self.assertListEqual(rawio._signal_limits, [(1614363777.825263, 1614363778.481169)])
110108
# check that there are only 3 memmaps
111109
self.assertEqual(len(rawio._sigs_memmaps[seg_idx]), 3)
112110

0 commit comments

Comments (0)