Skip to content

Commit 877123a

Browse files
sprengerJuliaSprenger
authored andcommitted
[Neuralynx] adjust test for streams and precalculate signal length
1 parent da4c04e commit 877123a

File tree

4 files changed

+42
-29
lines changed

4 files changed

+42
-29
lines changed

neo/rawio/neuralynxrawio/ncssections.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,9 @@ def _buildGivenActualFrequency(ncsMemMap, actualSampFreq, reqFreq):
228228
raise IOError("Sampling frequency in first record doesn't agree with header.")
229229
chanNum = ncsMemMap['channel_id'][0]
230230

231-
nb = NcsSections()
232-
nb.sampFreqUsed = actualSampFreq
233-
nb.microsPerSampUsed = NcsSectionsFactory.get_micros_per_samp_for_freq(actualSampFreq)
231+
secs = NcsSections()
232+
secs.sampFreqUsed = actualSampFreq
233+
secs.microsPerSampUsed = NcsSectionsFactory.get_micros_per_samp_for_freq(actualSampFreq)
234234

235235
# check if file is one block of records, which is often the case, and avoid full parse
236236
lastBlkI = ncsMemMap.shape[0] - 1
@@ -248,15 +248,15 @@ def _buildGivenActualFrequency(ncsMemMap, actualSampFreq, reqFreq):
248248
n_samples = NcsSection._RECORD_SIZE * lastBlkI
249249
curBlock = NcsSection(0, ts0, lastBlkI, lastBlkEndTime, n_samples)
250250

251-
nb.sects.append(curBlock)
252-
return nb
251+
secs.sects.append(curBlock)
252+
return secs
253253

254254
# otherwise need to scan looking for breaks
255255
else:
256256
blkOnePredTime = NcsSectionsFactory.calc_sample_time(actualSampFreq, ts0, nb0)
257257
curBlock = NcsSection(0, ts0, -1, -1, -1)
258-
nb.sects.append(curBlock)
259-
return NcsSectionsFactory._parseGivenActualFrequency(ncsMemMap, nb, chanNum, reqFreq,
258+
secs.sects.append(curBlock)
259+
return NcsSectionsFactory._parseGivenActualFrequency(ncsMemMap, secs, chanNum, reqFreq,
260260
blkOnePredTime)
261261

262262
@staticmethod

neo/rawio/neuralynxrawio/neuralynxrawio.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,6 @@ def _source_name(self):
117117
else:
118118
return self.dirname
119119

120-
# from memory_profiler import profile
121-
#
122-
# @profile()
123120
def _parse_header(self):
124121

125122
stream_channels = []
@@ -192,7 +189,7 @@ def _parse_header(self):
192189
for idx, chan_id in enumerate(chan_ids):
193190
chan_name = chan_names[idx]
194191

195-
chan_uid = (chan_name, chan_id)
192+
chan_uid = (chan_name, str(chan_id))
196193
if ext == 'ncs':
197194
if info['sampling_rate'] not in ncs_sampling_rates:
198195
ncs_sampling_rates.append(info['sampling_rate'])
@@ -310,6 +307,14 @@ def _parse_header(self):
310307
self._sigs_t_start = ncsSegTimestampLimits.t_start.copy()
311308
self._sigs_t_stop = ncsSegTimestampLimits.t_stop.copy()
312309

310+
# precompute signal lengths within segments
311+
self._sigs_length = []
312+
if self._sigs_memmaps:
313+
for seg_idx, sig_container in enumerate(self._sigs_memmaps):
314+
self._sigs_length.append({})
315+
for chan_uid, sig_infos in sig_container.items():
316+
self._sigs_length[seg_idx][chan_uid] = int(sig_infos['nb_valid'].sum())
317+
313318
# Determine timestamp limits in nev, nse file by scanning them.
314319
ts0, ts1 = None, None
315320
for _data_memmap in (self._spike_memmap, self._nev_memmap):
@@ -438,12 +443,12 @@ def _segment_t_stop(self, block_index, seg_index):
438443

439444
def _get_signal_size(self, block_index, seg_index, stream_index):
440445
stream_id = self.header['signal_streams'][stream_index]['id']
441-
channel_indexes = np.where(self.header['signal_channels']['stream_id'] == stream_id)[0]
446+
stream_mask = self.header['signal_channels']['stream_id'] == stream_id
447+
signals = self.header['signal_channels'][stream_mask]
442448

443-
if len(channel_indexes):
444-
channel_index = channel_indexes[0]
445-
chan = self.header['signal_channels'][channel_index]
446-
return self._sigs_memmaps[seg_index][(chan['name'], int(chan['id']))]
449+
if len(signals):
450+
sig = signals[0]
451+
return self._sigs_length[seg_index][(sig['name'], sig['id'])]
447452
else:
448453
return None
449454

@@ -475,7 +480,8 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
475480
if i_start is None:
476481
i_start = 0
477482
if i_stop is None:
478-
i_stop = self._sigs_length[seg_index]
483+
i_stop = self.get_signal_size(block_index=block_index, seg_index=seg_index,
484+
stream_index=stream_index)
479485

480486
block_start = i_start // NcsSection._RECORD_SIZE
481487
block_stop = i_stop // NcsSection._RECORD_SIZE + 1
@@ -489,7 +495,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
489495
stream_mask = self.header['signal_channels']['stream_id'] == stream_id
490496

491497
# channel_streams = self.
492-
channel_ids = self.header['signal_channels'][stream_mask][channel_indexes]['id'].astype(int)
498+
channel_ids = self.header['signal_channels'][stream_mask][channel_indexes]['id']
493499
channel_names = self.header['signal_channels'][stream_mask][channel_indexes]['name']
494500

495501
# create buffer for samples
@@ -632,10 +638,9 @@ def scan_ncs_files(self, ncs_filenames):
632638
nlxHeader = NlxHeader(ncs_filename)
633639

634640
if not chanSectMap or (chanSectMap and
635-
not NcsSectionsFactory._verifySectionsStructure(data,
636-
lastNcsSections)):
637-
lastNcsSections = NcsSectionsFactory.build_for_ncs_file(data, nlxHeader)
638-
chanSectMap[chan_uid] = [lastNcsSections, nlxHeader, ncs_filename]
641+
not NcsSectionsFactory._verifySectionsStructure(data, chan_ncs_sections)):
642+
chan_ncs_sections = NcsSectionsFactory.build_for_ncs_file(data, nlxHeader)
643+
chanSectMap[chan_uid] = [chan_ncs_sections, nlxHeader, ncs_filename]
639644
del data
640645

641646
# Construct an inverse dictionary from NcsSections to list of associated chan_uids
@@ -646,7 +651,7 @@ def scan_ncs_files(self, ncs_filenames):
646651
latest_sections = v[0]
647652
# time tolerance of +- one data package (in microsec)
648653
tolerance = 512 / min(v[0].sampFreqUsed, latest_sections.sampFreqUsed) *1e6
649-
if v[0].is_equivalent(latest_sections, abs_tol=tolerance):
654+
if not v[0].is_equivalent(latest_sections, abs_tol=tolerance):
650655
revSectMap.setdefault(latest_sections, []).append(k)
651656
else:
652657
revSectMap[v[0]] = [k]
@@ -658,7 +663,7 @@ def scan_ncs_files(self, ncs_filenames):
658663
raise IOError(f'ncs files have {len(revSectMap)} different sections '
659664
f'structures. Unsupported configuration.')
660665

661-
seg_time_limits = SegmentTimeLimits(nb_segment=len(lastNcsSections.sects),
666+
seg_time_limits = SegmentTimeLimits(nb_segment=len(chan_ncs_sections.sects),
662667
t_start=[], t_stop=[], length=[],
663668
timestamp_limits=[])
664669
memmaps = [{} for seg_index in range(seg_time_limits.nb_segment)]

neo/test/iotest/test_neuralynxio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def test_ncs(self):
275275

276276
# check that data agrees in first segment first channel only
277277
for anasig_id, anasig in enumerate(block.segments[0].analogsignals):
278-
chid = int(anasig.array_annotations['channel_ids'][0])
278+
chid = anasig.array_annotations['channel_ids'][0]
279279

280280
chname = str(anasig.array_annotations['channel_names'][0])
281281
chuid = (chname, chid)

neo/test/rawiotest/test_neuralynxrawio.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_scan_ncs_files(self):
3939
# test values here from direct inspection of .ncs files
4040
self.assertEqual(rawio._nb_segment, 1)
4141
self.assertListEqual(rawio._timestamp_limits, [(0, 192000)])
42-
self.assertEqual(rawio._sigs_length[0], 4608)
42+
self.assertEqual(rawio._sigs_length[0][('unknown', '1')], 4608)
4343
self.assertEqual(rawio._sigs_t_start[0], 0)
4444
self.assertEqual(rawio._sigs_t_stop[0], 0.192)
4545
self.assertEqual(len(rawio._sigs_memmaps), 1)
@@ -51,8 +51,9 @@ def test_scan_ncs_files(self):
5151
rawio.parse_header()
5252
# test values here from direct inspection of .ncs files
5353
self.assertEqual(rawio._nb_segment, 1)
54+
self.assertEqual(rawio.signal_streams_count(), 1)
5455
self.assertListEqual(rawio._timestamp_limits, [(266982936, 267162136)])
55-
self.assertEqual(rawio._sigs_length[0], 5120)
56+
self.assertEqual(rawio._sigs_length[0][('unknown', '13')], 5120)
5657
self.assertEqual(rawio._sigs_t_start[0], 266.982936)
5758
self.assertEqual(rawio._sigs_t_stop[0], 267.162136)
5859
self.assertEqual(len(rawio._sigs_memmaps), 1)
@@ -63,9 +64,13 @@ def test_scan_ncs_files(self):
6364
rawio.parse_header()
6465
# test values here from direct inspection of .ncs files
6566
self.assertEqual(rawio._nb_segment, 2)
67+
self.assertEqual(rawio.signal_streams_count(), 1)
6668
self.assertListEqual(rawio._timestamp_limits, [(26122557633, 26162525633),
6769
(26366360633, 26379704633)])
68-
self.assertListEqual(rawio._sigs_length, [1278976, 427008])
70+
self.assertDictEqual(rawio._sigs_length[0],
71+
{('Tet3a', '8'): 1278976, ('Tet3b', '9'): 1278976})
72+
self.assertDictEqual(rawio._sigs_length[1],
73+
{('Tet3a', '8'): 427008, ('Tet3b', '9'): 427008})
6974
self.assertListEqual(rawio._sigs_t_stop, [26162.525633, 26379.704633])
7075
self.assertListEqual(rawio._sigs_t_start, [26122.557633, 26366.360633])
7176
self.assertEqual(len(rawio._sigs_memmaps), 2) # check only that there are 2 memmaps
@@ -77,10 +82,13 @@ def test_scan_ncs_files(self):
7782
# test values here from direct inspection of .ncs file, except for 3rd block
7883
# t_stop, which is extended due to events past the last block of ncs records.
7984
self.assertEqual(rawio._nb_segment, 3)
85+
self.assertEqual(rawio.signal_streams_count(), 1)
8086
self.assertListEqual(rawio._timestamp_limits, [(8408806811, 8427831990),
8187
(8427832053, 8487768498),
8288
(8487768561, 8515816549)])
83-
self.assertListEqual(rawio._sigs_length, [608806, 1917967, 897536])
89+
self.assertDictEqual(rawio._sigs_length[0], {('CSC1', '48'): 608806})
90+
self.assertDictEqual(rawio._sigs_length[1], {('CSC1', '48'): 1917967})
91+
self.assertDictEqual(rawio._sigs_length[2], {('CSC1', '48'): 897536})
8492
self.assertListEqual(rawio._sigs_t_stop, [8427.831990, 8487.768498, 8515.816549])
8593
self.assertListEqual(rawio._sigs_t_start, [8408.806811, 8427.832053, 8487.768561])
8694
self.assertEqual(len(rawio._sigs_memmaps), 3) # check only that there are 3 memmaps

0 commit comments

Comments
 (0)