Skip to content

Commit b384214

Browse files
authored
Merge pull request #990 from JuliaSprenger/fix/neuralynx_mem
[neuralynx] memory improvements
2 parents 4a4514a + d76ea1f commit b384214

File tree

3 files changed

+194
-83
lines changed

3 files changed

+194
-83
lines changed

neo/rawio/neuralynxrawio/ncssections.py

Lines changed: 106 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import math
2+
import numpy as np
23

34

45
class NcsSections:
@@ -7,7 +8,7 @@ class NcsSections:
78
Methods of NcsSectionsFactory perform parsing of this information from an Ncs file and
89
produce these where the sections are discontiguous in time and in temporal order.
910
10-
TODO: This class will likely need __eq__, __ne__, and __hash__ to be useful in
11+
TODO: This class will likely need __ne__ to be useful in
1112
more sophisticated segment construction algorithms.
1213
1314
"""
@@ -16,6 +17,16 @@ def __init__(self):
1617
self.sampFreqUsed = 0 # actual sampling frequency of samples
1718
self.microsPerSampUsed = 0 # microseconds per sample
1819

20+
def __eq__(self, other):
21+
samp_eq = self.sampFreqUsed == other.sampFreqUsed
22+
micros_eq = self.microsPerSampUsed == other.microsPerSampUsed
23+
sects_eq = self.sects == other.sects
24+
return (samp_eq and micros_eq and sects_eq)
25+
26+
def __hash__(self):
27+
return (f'{self.sampFreqUsed};{self.microsPerSampUsed};'
28+
f'{[s.__hash__() for s in self.sects]}').__hash__()
29+
1930

2031
class NcsSection:
2132
"""
@@ -37,11 +48,23 @@ def __init__(self):
3748
self.endTime = -1 # end time of last record, that is, the end time of the last
3849
# sampling period contained in the last record of the section
3950

40-
def __init__(self, sb, st, eb, et):
51+
def __init__(self, sb, st, eb, et, ns):
4152
self.startRec = sb
4253
self.startTime = st
4354
self.endRec = eb
4455
self.endTime = et
56+
self.n_samples = ns
57+
58+
def __eq__(self, other):
59+
return (self.startRec == other.startRec
60+
and self.startTime == other.startTime
61+
and self.endRec == other.endRec
62+
and self.endTime == other.endTime
63+
and self.n_samples == other.n_samples)
64+
65+
def __hash__(self):
66+
s = f'{self.startRec};{self.startTime};{self.endRec};{self.endTime};{self.n_samples}'
67+
return s.__hash__()
4568

4669
def before_time(self, rhb):
4770
"""
@@ -124,32 +147,38 @@ def _parseGivenActualFrequency(ncsMemMap, ncsSects, chanNum, reqFreq, blkOnePred
124147
NcsSections object with block locations marked
125148
"""
126149
startBlockPredTime = blkOnePredTime
127-
blkLen = 0
150+
blk_len = 0
128151
curBlock = ncsSects.sects[0]
129152
for recn in range(1, ncsMemMap.shape[0]):
130-
if ncsMemMap['channel_id'][recn] != chanNum or \
131-
ncsMemMap['sample_rate'][recn] != reqFreq:
153+
timestamp = ncsMemMap['timestamp'][recn]
154+
channel_id = ncsMemMap['channel_id'][recn]
155+
sample_rate = ncsMemMap['sample_rate'][recn]
156+
nb_valid = ncsMemMap['nb_valid'][recn]
157+
158+
if channel_id != chanNum or sample_rate != reqFreq:
132159
raise IOError('Channel number or sampling frequency changed in ' +
133160
'records within file')
134161
predTime = NcsSectionsFactory.calc_sample_time(ncsSects.sampFreqUsed,
135-
startBlockPredTime, blkLen)
136-
ts = ncsMemMap['timestamp'][recn]
137-
nValidSamps = ncsMemMap['nb_valid'][recn]
138-
if ts != predTime:
162+
startBlockPredTime, blk_len)
163+
nValidSamps = nb_valid
164+
if timestamp != predTime:
139165
curBlock.endRec = recn - 1
140166
curBlock.endTime = predTime
141-
curBlock = NcsSection(recn, ts, -1, -1)
167+
curBlock.n_samples = blk_len
168+
curBlock = NcsSection(recn, timestamp, -1, -1, -1)
142169
ncsSects.sects.append(curBlock)
143170
startBlockPredTime = NcsSectionsFactory.calc_sample_time(
144-
ncsSects.sampFreqUsed, ts, nValidSamps)
145-
blkLen = 0
171+
ncsSects.sampFreqUsed,
172+
timestamp,
173+
nValidSamps)
174+
blk_len = 0
146175
else:
147-
blkLen += nValidSamps
176+
blk_len += nValidSamps
148177

149178
curBlock.endRec = ncsMemMap.shape[0] - 1
150179
endTime = NcsSectionsFactory.calc_sample_time(ncsSects.sampFreqUsed,
151180
startBlockPredTime,
152-
blkLen)
181+
blk_len)
153182
curBlock.endTime = endTime
154183

155184
return ncsSects
@@ -199,15 +228,16 @@ def _buildGivenActualFrequency(ncsMemMap, actualSampFreq, reqFreq):
199228
ncsMemMap['sample_rate'][lastBlkI] == reqFreq and \
200229
lts == predLastBlockStartTime:
201230
lastBlkEndTime = NcsSectionsFactory.calc_sample_time(actualSampFreq, lts, lnb)
202-
curBlock = NcsSection(0, ts0, lastBlkI, lastBlkEndTime)
231+
n_samples = NcsSection._RECORD_SIZE * lastBlkI
232+
curBlock = NcsSection(0, ts0, lastBlkI, lastBlkEndTime, n_samples)
203233

204234
nb.sects.append(curBlock)
205235
return nb
206236

207237
# otherwise need to scan looking for breaks
208238
else:
209239
blkOnePredTime = NcsSectionsFactory.calc_sample_time(actualSampFreq, ts0, nb0)
210-
curBlock = NcsSection(0, ts0, -1, -1)
240+
curBlock = NcsSection(0, ts0, -1, -1, -1)
211241
nb.sects.append(curBlock)
212242
return NcsSectionsFactory._parseGivenActualFrequency(ncsMemMap, nb, chanNum, reqFreq,
213243
blkOnePredTime)
@@ -233,60 +263,72 @@ def _parseForMaxGap(ncsMemMap, ncsSects, maxGapLen):
233263
largest block
234264
"""
235265

236-
# track frequency of each block and use estimate with longest block
237-
maxBlkLen = 0
238-
maxBlkFreqEstimate = 0
239-
240-
# Parse the record sequence, finding blocks of continuous time with no more than
241-
# maxGapLength and same channel number
242266
chanNum = ncsMemMap['channel_id'][0]
243-
244-
startBlockTime = ncsMemMap['timestamp'][0]
245-
blkLen = ncsMemMap['nb_valid'][0]
246-
lastRecTime = startBlockTime
247-
lastRecNumSamps = blkLen
248267
recFreq = ncsMemMap['sample_rate'][0]
249268

250-
curBlock = NcsSection(0, startBlockTime, -1, -1)
251-
ncsSects.sects.append(curBlock)
252-
for recn in range(1, ncsMemMap.shape[0]):
253-
if ncsMemMap['channel_id'][recn] != chanNum or \
254-
ncsMemMap['sample_rate'][recn] != recFreq:
255-
raise IOError('Channel number or sampling frequency changed in ' +
256-
'records within file')
257-
predTime = NcsSectionsFactory.calc_sample_time(ncsSects.sampFreqUsed, lastRecTime,
258-
lastRecNumSamps)
259-
ts = ncsMemMap['timestamp'][recn]
260-
nb = ncsMemMap['nb_valid'][recn]
261-
if abs(ts - predTime) > maxGapLen:
262-
curBlock.endRec = recn - 1
263-
curBlock.endTime = predTime
264-
curBlock = NcsSection(recn, ts, -1, -1)
265-
ncsSects.sects.append(curBlock)
266-
if blkLen > maxBlkLen:
267-
maxBlkLen = blkLen
268-
maxBlkFreqEstimate = (blkLen - lastRecNumSamps) * 1e6 / \
269-
(lastRecTime - startBlockTime)
270-
startBlockTime = ts
271-
blkLen = nb
272-
else:
273-
blkLen += nb
274-
lastRecTime = ts
275-
lastRecNumSamps = nb
276-
277-
if blkLen > maxBlkLen:
278-
maxBlkFreqEstimate = (blkLen - lastRecNumSamps) * 1e6 / \
279-
(lastRecTime - startBlockTime)
280-
281-
curBlock.endRec = ncsMemMap.shape[0] - 1
282-
endTime = NcsSectionsFactory.calc_sample_time(ncsSects.sampFreqUsed, lastRecTime,
283-
lastRecNumSamps)
284-
curBlock.endTime = endTime
269+
# check for consistent channel_ids and sampling rates
270+
ncsMemMap['channel_id']
271+
if not (ncsMemMap['channel_id'] == chanNum).all():
272+
raise IOError('Channel number changed in records within file')
273+
274+
if not all(ncsMemMap['sample_rate'] == recFreq):
275+
raise IOError('Sampling frequency changed in records within file')
276+
277+
# find the most frequent number of valid samples per record
278+
exp_nb_valid = np.argmax(np.bincount(ncsMemMap['nb_valid']))
279+
# detect records with incomplete number of samples
280+
gap_rec_ids = list(np.where(ncsMemMap['nb_valid'] != exp_nb_valid)[0])
281+
282+
rec_duration = 1e6 / ncsSects.sampFreqUsed * ncsMemMap['nb_valid']
283+
pred_times = np.rint(ncsMemMap['timestamp'] + rec_duration).astype(np.int64)
284+
max_pred_times = pred_times + maxGapLen
285+
# data records that start later than the predicted time (including the
286+
# maximal accepted gap length) are considered delayed and a gap is
287+
# registered.
288+
delayed_recs = list(np.where(max_pred_times[:-1] < ncsMemMap['timestamp'][1:])[0])
289+
gap_rec_ids.extend(delayed_recs)
290+
291+
# cleaning extracted gap ids
292+
# the last record cannot be the beginning of a gap
293+
last_rec_id = len(ncsMemMap['timestamp']) - 1
294+
if last_rec_id in gap_rec_ids:
295+
gap_rec_ids.remove(last_rec_id)
296+
297+
# gap ids can only be listed once
298+
gap_rec_ids = sorted(set(gap_rec_ids))
299+
300+
# create recording segments from identified gaps
301+
ncsSects.sects.append(NcsSection(0, ncsMemMap['timestamp'][0], -1, -1, -1))
302+
for gap_rec_id in gap_rec_ids:
303+
curr_sec = ncsSects.sects[-1]
304+
curr_sec.endRec = gap_rec_id
305+
curr_sec.endTime = pred_times[gap_rec_id]
306+
n_samples = np.sum(ncsMemMap['nb_valid'][curr_sec.startRec:gap_rec_id + 1])
307+
curr_sec.n_samples = n_samples
308+
309+
next_sec = NcsSection(gap_rec_id + 1,
310+
ncsMemMap['timestamp'][gap_rec_id + 1], -1, -1, -1)
311+
ncsSects.sects.append(next_sec)
312+
313+
curr_sec = ncsSects.sects[-1]
314+
curr_sec.endRec = len(ncsMemMap['timestamp']) - 1
315+
curr_sec.endTime = pred_times[-1]
316+
n_samples = np.sum(ncsMemMap['nb_valid'][curr_sec.startRec:])
317+
curr_sec.n_samples = n_samples
318+
319+
# calculate the estimated frequency of the block spanning the most records
320+
max_blk_idx = np.argmax([bl.endRec - bl.startRec for bl in ncsSects.sects])
321+
max_blk = ncsSects.sects[max_blk_idx]
322+
323+
maxBlkFreqEstimate = (max_blk.n_samples - ncsMemMap['nb_valid'][max_blk.endRec]) * 1e6 / \
324+
(ncsMemMap['timestamp'][max_blk.endRec] - max_blk.startTime)
285325

286326
ncsSects.sampFreqUsed = maxBlkFreqEstimate
287327
ncsSects.microsPerSampUsed = NcsSectionsFactory.get_micros_per_samp_for_freq(
288328
maxBlkFreqEstimate)
289-
329+
# free memory that is unnecessarily occupied by the memmap
330+
# (see https://github.com/numpy/numpy/issues/19340)
331+
del ncsMemMap
290332
return ncsSects
291333

292334
@staticmethod
@@ -325,7 +367,7 @@ def _buildForMaxGap(ncsMemMap, nomFreq):
325367
freqInFile = math.floor(nomFreq)
326368
if lts - predLastBlockStartTime == 0 and lcid == chanNum and lsr == freqInFile:
327369
endTime = NcsSectionsFactory.calc_sample_time(nomFreq, lts, lnb)
328-
curBlock = NcsSection(0, ts0, lastBlkI, endTime)
370+
curBlock = NcsSection(0, ts0, lastBlkI, endTime, numSampsForPred)
329371
nb.sects.append(curBlock)
330372
nb.sampFreqUsed = numSampsForPred / (lts - ts0) * 1e6
331373
nb.microsPerSampUsed = NcsSectionsFactory.get_micros_per_samp_for_freq(nb.sampFreqUsed)

neo/rawio/neuralynxrawio/neuralynxrawio.py

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747

4848
import numpy as np
4949
import os
50+
import pathlib
51+
import copy
5052
from collections import (namedtuple, OrderedDict)
5153

5254
from neo.rawio.neuralynxrawio.ncssections import (NcsSection, NcsSectionsFactory)
@@ -110,6 +112,9 @@ def _source_name(self):
110112
else:
111113
return self.dirname
112114

115+
# from memory_profiler import profile
116+
#
117+
# @profile()
113118
def _parse_header(self):
114119

115120
stream_channels = []
@@ -139,6 +144,11 @@ def _parse_header(self):
139144
filenames = sorted(os.listdir(self.dirname))
140145
dirname = self.dirname
141146
else:
147+
if not os.path.isfile(self.filename):
148+
raise ValueError(f'Provided Filename is not a file: '
149+
f'{self.filename}. If you want to provide a '
150+
f'directory use the `dirname` keyword')
151+
142152
dirname, fname = os.path.split(self.filename)
143153
filenames = [fname]
144154

@@ -209,15 +219,7 @@ def _parse_header(self):
209219
'Several nse or ntt files have the same unit_id!!!'
210220
self.nse_ntt_filenames[chan_uid] = filename
211221

212-
dtype = get_nse_or_ntt_dtype(info, ext)
213-
214-
if os.path.getsize(filename) <= NlxHeader.HEADER_SIZE:
215-
self._empty_nse_ntt.append(filename)
216-
data = np.zeros((0,), dtype=dtype)
217-
else:
218-
data = np.memmap(filename, dtype=dtype, mode='r',
219-
offset=NlxHeader.HEADER_SIZE)
220-
222+
data = self._get_file_map(filename)
221223
self._spike_memmap[chan_uid] = data
222224

223225
unit_ids = np.unique(data['unit_id'])
@@ -249,8 +251,7 @@ def _parse_header(self):
249251
data = np.zeros((0,), dtype=nev_dtype)
250252
internal_ids = []
251253
else:
252-
data = np.memmap(filename, dtype=nev_dtype, mode='r',
253-
offset=NlxHeader.HEADER_SIZE)
254+
data = self._get_file_map(filename)
254255
internal_ids = np.unique(data[['event_id', 'ttl_input']]).tolist()
255256
for internal_event_id in internal_ids:
256257
if internal_event_id not in self.internal_event_ids:
@@ -378,6 +379,37 @@ def _parse_header(self):
378379
# ~ ev_ann['digital_marker'] =
379380
# ~ ev_ann['analog_marker'] =
380381

382+
def _get_file_map(self, filename):
383+
"""
384+
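Create a memory map for a file, choosing the dtype from its suffix
(ncs, nse, ntt or nev); an nse/ntt file without data yields an empty array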
385+
see also https://github.com/numpy/numpy/issues/19340
386+
"""
387+
filename = pathlib.Path(filename)
388+
suffix = filename.suffix.lower()[1:]
389+
390+
if suffix == 'ncs':
391+
return np.memmap(filename, dtype=self._ncs_dtype, mode='r',
392+
offset=NlxHeader.HEADER_SIZE)
393+
394+
elif suffix in ['nse', 'ntt']:
395+
info = NlxHeader(filename)
396+
dtype = get_nse_or_ntt_dtype(info, suffix)
397+
398+
# return empty map if file does not contain data
399+
if os.path.getsize(filename) <= NlxHeader.HEADER_SIZE:
400+
self._empty_nse_ntt.append(filename)
401+
return np.zeros((0,), dtype=dtype)
402+
403+
return np.memmap(filename, dtype=dtype, mode='r',
404+
offset=NlxHeader.HEADER_SIZE)
405+
406+
elif suffix == 'nev':
407+
return np.memmap(filename, dtype=nev_dtype, mode='r',
408+
offset=NlxHeader.HEADER_SIZE)
409+
410+
else:
411+
raise ValueError(f'Unknown file suffix {suffix}')
412+
381413
# Accessors for segment times which are offset by appropriate global start time
382414
def _segment_t_start(self, block_index, seg_index):
383415
return self._seg_t_starts[seg_index] - self.global_t_start
@@ -565,16 +597,15 @@ def scan_ncs_files(self, ncs_filenames):
565597
chanSectMap = dict()
566598
for chan_uid, ncs_filename in self.ncs_filenames.items():
567599

568-
data = np.memmap(ncs_filename, dtype=self._ncs_dtype, mode='r',
569-
offset=NlxHeader.HEADER_SIZE)
600+
data = self._get_file_map(ncs_filename)
570601
nlxHeader = NlxHeader(ncs_filename)
571602

572603
if not chanSectMap or (chanSectMap and
573604
not NcsSectionsFactory._verifySectionsStructure(data,
574605
lastNcsSections)):
575606
lastNcsSections = NcsSectionsFactory.build_for_ncs_file(data, nlxHeader)
576-
577-
chanSectMap[chan_uid] = [lastNcsSections, nlxHeader, data]
607+
chanSectMap[chan_uid] = [lastNcsSections, nlxHeader, ncs_filename]
608+
del data
578609

579610
# Construct an inverse dictionary from NcsSections to list of associated chan_uids
580611
revSectMap = dict()
@@ -584,8 +615,8 @@ def scan_ncs_files(self, ncs_filenames):
584615
# If there is only one NcsSections structure in the set of ncs files, there should only
585616
# be one entry. Otherwise this is presently unsupported.
586617
if len(revSectMap) > 1:
587-
raise IOError('ncs files have {} different sections structures. Unsupported.'.format(
588-
len(revSectMap)))
618+
raise IOError(f'ncs files have {len(revSectMap)} different sections '
619+
f'structures. Unsupported configuration.')
589620

590621
seg_time_limits = SegmentTimeLimits(nb_segment=len(lastNcsSections.sects),
591622
t_start=[], t_stop=[], length=[],
@@ -595,7 +626,7 @@ def scan_ncs_files(self, ncs_filenames):
595626
# create segment with subdata block/t_start/t_stop/length for each channel
596627
for i, fileEntry in enumerate(self.ncs_filenames.items()):
597628
chan_uid = fileEntry[0]
598-
data = chanSectMap[chan_uid][2]
629+
data = self._get_file_map(chanSectMap[chan_uid][2])
599630

600631
# create a memmap for each record section of the current file
601632
curSects = chanSectMap[chan_uid][0]

0 commit comments

Comments
 (0)