Skip to content

Commit 421b5ed

Browse files
JuliaSprenger (sprenger)
authored and committed
[Neuralynx] introduce 2-level section checking: within stream and across streams
1 parent 534c3b0 commit 421b5ed

File tree

1 file changed

+126
-48
lines changed

1 file changed

+126
-48
lines changed

neo/rawio/neuralynxrawio/neuralynxrawio.py

Lines changed: 126 additions & 48 deletions
Original file line number · Diff line number · Diff line change
@@ -44,7 +44,7 @@
4444

4545
from ..baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
4646
_spike_channel_dtype, _event_channel_dtype)
47-
47+
from operator import itemgetter
4848
import numpy as np
4949
import os
5050
import pathlib
@@ -163,8 +163,8 @@ def _parse_header(self):
163163
if excl_file in filenames:
164164
filenames.remove(excl_file)
165165

166-
ncs_sampling_rates = []
167-
stream_id = -1 # will be increased to 0 for first signal
166+
stream_props = {} # {(sampling_rate, n_samples, t_start):
167+
# {stream_id: [filenames]}
168168

169169
for filename in filenames:
170170
filename = os.path.join(dirname, filename)
@@ -191,9 +191,19 @@ def _parse_header(self):
191191

192192
chan_uid = (chan_name, str(chan_id))
193193
if ext == 'ncs':
194-
if info['sampling_rate'] not in ncs_sampling_rates:
195-
ncs_sampling_rates.append(info['sampling_rate'])
196-
stream_id += 1
194+
file_mmap = self._get_file_map(filename)
195+
n_packets = copy.copy(file_mmap.shape[0])
196+
if n_packets:
197+
t_start = copy.copy(file_mmap[0][0])
198+
else: # empty file
199+
t_start = 0
200+
stream_prop = (info['sampling_rate'], n_packets, t_start)
201+
if stream_prop not in stream_props:
202+
stream_props[stream_prop] = {'stream_id': len(stream_props),
203+
'filenames': [filename]}
204+
else:
205+
stream_props[stream_prop]['filenames'].append(filename)
206+
stream_id = stream_props[stream_prop]['stream_id']
197207

198208
# a sampled signal channel
199209
units = 'uV'
@@ -285,10 +295,11 @@ def _parse_header(self):
285295
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
286296
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
287297

288-
# require all sampled signals, ncs files, to have the same sampling rate
289298
if signal_channels.size > 0:
290-
names = [f'signals ({sr}Hz)' for sr in ncs_sampling_rates]
291-
signal_streams = list(zip(names, range(len(names))))
299+
names = [f'Stream with (sampling_rate, n_packets, t_start): ' \
300+
f'({stream_prop})' for stream_prop in stream_props]
301+
ids = [stream_prop['stream_id'] for stream_prop in stream_props.values()]
302+
signal_streams = list(zip(names, ids))
292303
else:
293304
signal_streams = []
294305
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
@@ -298,14 +309,81 @@ def _parse_header(self):
298309
self._timestamp_limits = None
299310
self._nb_segment = 1
300311

301-
# Read ncs files for gap detection and nb_segment computation.
302-
self._sigs_memmaps, ncsSegTimestampLimits = self.scan_ncs_files(self.ncs_filenames)
303-
if ncsSegTimestampLimits:
304-
self._ncs_seg_timestamp_limits = ncsSegTimestampLimits # save copy
305-
self._nb_segment = ncsSegTimestampLimits.nb_segment
306-
self._timestamp_limits = ncsSegTimestampLimits.timestamp_limits.copy()
307-
self._sigs_t_start = ncsSegTimestampLimits.t_start.copy()
308-
self._sigs_t_stop = ncsSegTimestampLimits.t_stop.copy()
312+
stream_infos = {}
313+
314+
# Read ncs files of each stream for gap detection and nb_segment computation.
315+
for stream_id in np.unique(signal_channels['stream_id']):
316+
stream_channels = signal_channels[signal_channels['stream_id'] == stream_id]
317+
stream_chan_uids = zip(stream_channels['name'], stream_channels['id'])
318+
stream_filenames = itemgetter(*stream_chan_uids)(self.ncs_filenames)
319+
_sigs_memmaps, ncsSegTimestampLimits, section_structure = self.scan_stream_ncs_files(stream_filenames)
320+
321+
stream_infos[stream_id] = {'segment_sig_memmaps': _sigs_memmaps,
322+
'ncs_segment_infos': ncsSegTimestampLimits,
323+
'section_structure': section_structure}
324+
325+
# check if section structure across streams is compatible and merge infos
326+
ref_stream_id = list(stream_infos.keys())[0]
327+
ref_sec_structure = stream_infos[ref_stream_id]['section_structure']
328+
for stream_id, stream_info in stream_infos.items():
329+
sec_structure = stream_info['section_structure']
330+
331+
# check if section structure of streams are compatible
332+
# using tolerance of one data packet (512 samples)
333+
tolerance = 512 / min(ref_sec_structure.sampFreqUsed,
334+
sec_structure.sampFreqUsed) * 1e6
335+
if not ref_sec_structure.is_equivalent(sec_structure, abs_tol=tolerance):
336+
ref_chan_ids = signal_channels[signal_channels['stream_id'] == ref_stream_id]['name']
337+
chan_ids = signal_channels[signal_channels['stream_id'] == stream_id]['name']
338+
339+
raise ValueError('Incompatible section structures across streams: '
340+
f'Stream id {ref_stream_id}:{ref_chan_ids} and '
341+
f'{stream_id}:{chan_ids}.')
342+
343+
self._nb_segment = len(ref_sec_structure.sects)
344+
345+
# merge stream mmemmaps since streams are compatible
346+
self._sigs_memmaps = [{} for seg_idx in range(self._nb_segment)]
347+
self._timestamp_limits = [(None, None) for seg_idx in range(self._nb_segment)]
348+
self._signal_limits = [(None, None) for seg_idx in range(self._nb_segment)]
349+
for stream_id, stream_info in stream_infos.items():
350+
stream_mmaps = stream_info['segment_sig_memmaps']
351+
for seg_idx, signal_dict in enumerate(stream_mmaps):
352+
self._sigs_memmaps[seg_idx].update(signal_dict)
353+
354+
ncs_segment_info = stream_info['ncs_segment_infos']
355+
for seg_idx, (t_start, t_stop) in enumerate(ncs_segment_info.timestamp_limits):
356+
old_times = self._timestamp_limits[seg_idx]
357+
if (old_times[0] is None) or (t_start < old_times[0]):
358+
self._timestamp_limits[seg_idx] = (t_start, self._signal_limits[seg_idx][1])
359+
if (self._timestamp_limits[seg_idx][1] is None) or (
360+
t_stop > self._timestamp_limits[seg_idx][1]):
361+
self._timestamp_limits[seg_idx] = (self._signal_limits[seg_idx][0],
362+
t_stop)
363+
364+
for seg_idx in range(ncs_segment_info.nb_segment):
365+
t_start = ncs_segment_info.t_start[seg_idx]
366+
t_stop = ncs_segment_info.t_stop[seg_idx]
367+
old_times = self._signal_limits[seg_idx]
368+
if (self._signal_limits[seg_idx][0] is None) or (
369+
t_start < self._signal_limits[seg_idx][0]):
370+
self._signal_limits[seg_idx] = (t_start, self._signal_limits[seg_idx][1])
371+
if (self._signal_limits[seg_idx][1] is None) or (
372+
t_stop > self._signal_limits[seg_idx][1]):
373+
self._signal_limits[seg_idx] = (self._signal_limits[seg_idx][0],
374+
t_stop)
375+
376+
# self._sigs_length = [{} for seg_idx in range(self._nb_segment)]
377+
# for stream_id, stream_info in stream_infos.items():
378+
# ncs_segment_info = stream_info['ncs_segment_infos']
379+
# chan_ids = signal_channels[signal_channels['stream_id'] == stream_id]['name']
380+
#
381+
# for chan_uid in chan_ids:
382+
#
383+
#
384+
# for seg_idx in range(self._nb_segment):
385+
386+
309387

310388
# precompute signal lengths within segments
311389
self._sigs_length = []
@@ -338,18 +416,18 @@ def _parse_header(self):
338416
self.global_t_stop = ts1 / 1e6
339417
elif ts0 is not None:
340418
# case HAVE ncs AND HAVE nev or nse
341-
self.global_t_start = min(ts0 / 1e6, self._sigs_t_start[0])
342-
self.global_t_stop = max(ts1 / 1e6, self._sigs_t_stop[-1])
343-
self._seg_t_starts = list(self._sigs_t_start)
419+
self.global_t_start = min(ts0, self._timestamp_limits[0][0]) /1e6
420+
self.global_t_stop = max(ts1 / 1e6, self._timestamp_limits[-1][-1])
421+
self._seg_t_starts = [limits[0] /1e6 for limits in self._timestamp_limits]
344422
self._seg_t_starts[0] = self.global_t_start
345-
self._seg_t_stops = list(self._sigs_t_stop)
423+
self._seg_t_stops = [limits[1] / 1e6 for limits in self._timestamp_limits]
346424
self._seg_t_stops[-1] = self.global_t_stop
347425
else:
348426
# case HAVE ncs but NO nev or nse
349-
self._seg_t_starts = self._sigs_t_start
350-
self._seg_t_stops = self._sigs_t_stop
351-
self.global_t_start = self._sigs_t_start[0]
352-
self.global_t_stop = self._sigs_t_stop[-1]
427+
self._seg_t_starts = [limits[0] / 1e6 for limits in self._timestamp_limits]
428+
self._seg_t_stops = [limits[1] / 1e6 for limits in self._timestamp_limits]
429+
self.global_t_start = self._signal_limits[0][0] / 1e6
430+
self.global_t_stop = self._signal_limits[-1][-1] / 1e6
353431

354432
if self.keep_original_times:
355433
self.global_t_stop = self.global_t_stop - self.global_t_start
@@ -603,9 +681,11 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index)
603681
event_times -= self.global_t_start
604682
return event_times
605683

606-
def scan_ncs_files(self, ncs_filenames):
684+
def scan_stream_ncs_files(self, ncs_filenames):
607685
"""
608686
Given a list of ncs files, read their basic structure.
687+
Ncs files have to have common sampling_rate, number of packets and t_start
688+
(be part of a single stream)
609689
610690
PARAMETERS:
611691
------
@@ -617,60 +697,56 @@ def scan_ncs_files(self, ncs_filenames):
617697
[ {} for seg_index in range(self._nb_segment) ][chan_uid]
618698
seg_time_limits
619699
SegmentTimeLimits for sections in scanned Ncs files
700+
section_structure
701+
Section structure common to the ncs files
620702
621703
Files will be scanned to determine the sections of records. If file is a single
622704
section of records, this scan is brief, otherwise it will check each record which may
623705
take some time.
624706
"""
625707

626-
# :TODO: Needs to account for gaps and start and end times potentially
627-
# being different in different groups of channels. These groups typically
628-
# correspond to the channels collected by a single ADC card.
629708
if len(ncs_filenames) == 0:
630-
return None, None
709+
return None, None, None
631710

632711
# Build dictionary of chan_uid to associated NcsSections, memmap and NlxHeaders. Only
633712
# construct new NcsSections when it is different from that for the preceding file.
634713
chanSectMap = dict()
635-
for chan_uid, ncs_filename in self.ncs_filenames.items():
714+
sig_length = []
715+
for ncs_filename in ncs_filenames:
636716

637717
data = self._get_file_map(ncs_filename)
638718
nlxHeader = NlxHeader(ncs_filename)
639719

640720
if not chanSectMap or (chanSectMap and
641721
not NcsSectionsFactory._verifySectionsStructure(data, chan_ncs_sections)):
642722
chan_ncs_sections = NcsSectionsFactory.build_for_ncs_file(data, nlxHeader)
643-
chanSectMap[chan_uid] = [chan_ncs_sections, nlxHeader, ncs_filename]
723+
724+
# register file section structure for all contained channels
725+
for chan_uid in zip(nlxHeader['channel_names'],
726+
np.asarray(nlxHeader['channel_ids'], dtype=str)):
727+
chanSectMap[chan_uid] = [chan_ncs_sections, nlxHeader, ncs_filename]
728+
644729
del data
645730

646731
# Construct an inverse dictionary from NcsSections to list of associated chan_uids
647-
# consider channels
648-
revSectMap = {}
649-
for i, (k, v) in enumerate(chanSectMap.items()):
650-
if i == 0: # start initially with first Ncssections
651-
latest_sections = v[0]
652-
# time tolerance of +- one data package (in microsec)
653-
tolerance = 512 / min(v[0].sampFreqUsed, latest_sections.sampFreqUsed) * 1e6
654-
if v[0].is_equivalent(latest_sections, abs_tol=tolerance):
655-
revSectMap.setdefault(latest_sections, []).append(k)
656-
else:
657-
revSectMap[v[0]] = [k]
658-
latest_sections = v[0]
732+
revSectMap = dict()
733+
for k, v in chanSectMap.items():
734+
revSectMap.setdefault(v[0], []).append(k)
659735

660736
# If there is only one NcsSections structure in the set of ncs files, there should only
661737
# be one entry. Otherwise this is presently unsupported.
662738
if len(revSectMap) > 1:
663739
raise IOError(f'ncs files have {len(revSectMap)} different sections '
664-
f'structures. Unsupported configuration.')
740+
f'structures. Unsupported configuration to be handled with in a single '
741+
f'stream.')
665742

666743
seg_time_limits = SegmentTimeLimits(nb_segment=len(chan_ncs_sections.sects),
667744
t_start=[], t_stop=[], length=[],
668745
timestamp_limits=[])
669746
memmaps = [{} for seg_index in range(seg_time_limits.nb_segment)]
670747

671748
# create segment with subdata block/t_start/t_stop/length for each channel
672-
for i, fileEntry in enumerate(self.ncs_filenames.items()):
673-
chan_uid = fileEntry[0]
749+
for i, chan_uid in enumerate(chanSectMap.keys()):
674750
data = self._get_file_map(chanSectMap[chan_uid][2])
675751

676752
# create a memmap for each record section of the current file
@@ -699,7 +775,9 @@ def scan_ncs_files(self, ncs_filenames):
699775
length = (subdata.size - 1) * NcsSection._RECORD_SIZE + numSampsLastSect
700776
seg_time_limits.length.append(length)
701777

702-
return memmaps, seg_time_limits
778+
stream_section_structure = list(revSectMap.keys())[0]
779+
780+
return memmaps, seg_time_limits, stream_section_structure
703781

704782

705783
# time limits for set of segments

0 commit comments

Comments (0)