Commit ca7e142

Merge pull request #1343 from bendichter/acc_plexon
PlexonRawIO style
2 parents 288e101 + 603a89b

neo/rawio/plexonrawio.py

Lines changed: 117 additions & 44 deletions
@@ -21,22 +21,43 @@
 Author: Samuel Garcia
 
 """
-
-from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
-                        _spike_channel_dtype, _event_channel_dtype)
+import datetime
+from collections import OrderedDict
 
 import numpy as np
-from collections import OrderedDict
-import datetime
+try:
+    from tqdm import tqdm, trange
+    HAVE_TQDM = True
+except ImportError:
+    HAVE_TQDM = False
+
+from .baserawio import (
+    BaseRawIO,
+    _signal_channel_dtype,
+    _signal_stream_dtype,
+    _spike_channel_dtype,
+    _event_channel_dtype,
+)
 
 
 class PlexonRawIO(BaseRawIO):
     extensions = ['plx']
     rawmode = 'one-file'
 
-    def __init__(self, filename=''):
+    def __init__(self, filename='', progress_bar=True):
+        """
+
+        Parameters
+        ----------
+        filename: str
+            The filename.
+        progress_bar: bool, default True
+            Display progress bar using tqdm (if installed) when parsing the file.
+
+        """
         BaseRawIO.__init__(self)
         self.filename = filename
+        self.progress_bar = HAVE_TQDM and progress_bar
 
     def _source_name(self):
         return self.filename
@@ -45,43 +66,57 @@ def _parse_header(self):
 
         # global header
         with open(self.filename, 'rb') as fid:
-            offset0 = 0
-            global_header = read_as_dict(fid, GlobalHeader, offset=offset0)
+            global_header = read_as_dict(fid, GlobalHeader)
 
-        rec_datetime = datetime.datetime(global_header['Year'],
-                                         global_header['Month'],
-                                         global_header['Day'],
-                                         global_header['Hour'],
-                                         global_header['Minute'],
-                                         global_header['Second'])
+        rec_datetime = datetime.datetime(
+            global_header['Year'],
+            global_header['Month'],
+            global_header['Day'],
+            global_header['Hour'],
+            global_header['Minute'],
+            global_header['Second'],
+        )
 
         # dsp channels header = spikes and waveforms
         nb_unit_chan = global_header['NumDSPChannels']
         offset1 = np.dtype(GlobalHeader).itemsize
-        dspChannelHeaders = np.memmap(self.filename, dtype=DspChannelHeader, mode='r',
-                                      offset=offset1, shape=(nb_unit_chan,))
+        dspChannelHeaders = np.memmap(
+            self.filename, dtype=DspChannelHeader, mode='r', offset=offset1, shape=(nb_unit_chan,)
+        )
 
         # event channel header
         nb_event_chan = global_header['NumEventChannels']
         offset2 = offset1 + np.dtype(DspChannelHeader).itemsize * nb_unit_chan
-        eventHeaders = np.memmap(self.filename, dtype=EventChannelHeader, mode='r',
-                                 offset=offset2, shape=(nb_event_chan,))
+        eventHeaders = np.memmap(
+            self.filename,
+            dtype=EventChannelHeader,
+            mode='r',
+            offset=offset2,
+            shape=(nb_event_chan,),
+        )
 
         # slow channel header = signal
         nb_sig_chan = global_header['NumSlowChannels']
         offset3 = offset2 + np.dtype(EventChannelHeader).itemsize * nb_event_chan
-        slowChannelHeaders = np.memmap(self.filename, dtype=SlowChannelHeader, mode='r',
-                                       offset=offset3, shape=(nb_sig_chan,))
+        slowChannelHeaders = np.memmap(
+            self.filename, dtype=SlowChannelHeader, mode='r', offset=offset3, shape=(nb_sig_chan,)
+        )
 
         offset4 = offset3 + np.dtype(SlowChannelHeader).itemsize * nb_sig_chan
 
         # locate data blocks and group them by type and channel
-        block_pos = {1: {c: [] for c in dspChannelHeaders['Channel']},
-                     4: {c: [] for c in eventHeaders['Channel']},
-                     5: {c: [] for c in slowChannelHeaders['Channel']},
-                     }
+        block_pos = {
+            1: {c: [] for c in dspChannelHeaders['Channel']},
+            4: {c: [] for c in eventHeaders['Channel']},
+            5: {c: [] for c in slowChannelHeaders['Channel']},
+        }
         data = self._memmap = np.memmap(self.filename, dtype='u1', offset=0, mode='r')
         pos = offset4
+
+        # Create a tqdm object with a total of len(data) and an initial value of 0 for offset
+        if self.progress_bar:
+            progress_bar = tqdm(total=len(data), initial=0, desc="Parsing data blocks", leave=True)
+
         while pos < data.size:
             bl_header = data[pos:pos + 16].view(DataBlockHeader)[0]
             length = bl_header['NumberOfWaveforms'] * bl_header['NumberOfWordsInWaveform'] * 2 + 16
@@ -90,6 +125,13 @@ def _parse_header(self):
             block_pos[bl_type][chan_id].append(pos)
             pos += length
 
+            # Update tqdm with the number of bytes processed in this iteration
+            if self.progress_bar:
+                progress_bar.update(length)
+
+        if self.progress_bar:
+            progress_bar.close()
+
         self._last_timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \
             2 ** 32 + bl_header['TimeStamp']
 
@@ -105,9 +147,21 @@ def _parse_header(self):
             # Signals
             5: np.dtype(dt_base + [('cumsum', 'int64'), ]),
         }
-        for bl_type in block_pos:
+        if self.progress_bar:
+            bl_loop = tqdm(block_pos, desc="Finalizing data blocks", leave=True)
+        else:
+            bl_loop = block_pos
+        for bl_type in bl_loop:
             self._data_blocks[bl_type] = {}
-            for chan_id in block_pos[bl_type]:
+            if self.progress_bar:
+                chan_loop = tqdm(
+                    block_pos[bl_type],
+                    desc="Finalizing data blocks for type %d" % bl_type,
+                    leave=True,
+                )
+            else:
+                chan_loop = block_pos[bl_type]
+            for chan_id in chan_loop:
                 positions = block_pos[bl_type][chan_id]
                 dt = dtype_by_bltype[bl_type]
                 data_block = np.empty((len(positions)), dtype=dt)
@@ -132,7 +186,7 @@ def _parse_header(self):
                         data_block['label'][index] = bl_header['Unit']
                     elif bl_type == 5:  # Signals
                         if data_block.size > 0:
-                            # cumulative some of sample index for fast access to chunks
+                            # cumulative sum of sample index for fast access to chunks
                             if index == 0:
                                 data_block['cumsum'][index] = 0
                             else:
@@ -143,7 +197,11 @@ def _parse_header(self):
         # signals channels
         sig_channels = []
         all_sig_length = []
-        for chan_index in range(nb_sig_chan):
+        if self.progress_bar:
+            chan_loop = trange(nb_sig_chan, desc="Parsing signal channels", leave=True)
+        else:
+            chan_loop = range(nb_sig_chan)
+        for chan_index in chan_loop:
             h = slowChannelHeaders[chan_index]
             name = h['Name'].decode('utf8')
             chan_id = h['Channel']
@@ -164,8 +222,9 @@ def _parse_header(self):
                 h['Gain'] * h['PreampGain'])
             offset = 0.
             stream_id = '0'
-            sig_channels.append((name, str(chan_id), sampling_rate, sig_dtype,
-                                 units, gain, offset, stream_id))
+            sig_channels.append(
+                (name, str(chan_id), sampling_rate, sig_dtype, units, gain, offset, stream_id)
+            )
 
         sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
 
@@ -203,7 +262,16 @@ def _parse_header(self):
 
         # Spikes channels
         spike_channels = []
-        for unit_index, (chan_id, unit_id) in enumerate(self.internal_unit_ids):
+        if self.progress_bar:
+            unit_loop = tqdm(
+                enumerate(self.internal_unit_ids),
+                desc="Parsing spike channels",
+                leave=True,
+            )
+        else:
+            unit_loop = enumerate(self.internal_unit_ids)
+
+        for unit_index, (chan_id, unit_id) in unit_loop:
             c = np.nonzero(dspChannelHeaders['Channel'] == chan_id)[0][0]
             h = dspChannelHeaders[c]
 
@@ -223,28 +291,33 @@ def _parse_header(self):
             wf_offset = 0.
             wf_left_sweep = -1  # DONT KNOWN
             wf_sampling_rate = global_header['WaveformFreq']
-            spike_channels.append((name, _id, wf_units, wf_gain, wf_offset,
-                                   wf_left_sweep, wf_sampling_rate))
+            spike_channels.append(
+                (name, _id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate)
+            )
         spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
 
         # Event channels
         event_channels = []
-        for chan_index in range(nb_event_chan):
+        if self.progress_bar:
+            chan_loop = trange(nb_event_chan, desc="Parsing event channels", leave=True)
+        else:
+            chan_loop = range(nb_event_chan)
+        for chan_index in chan_loop:
             h = eventHeaders[chan_index]
             chan_id = h['Channel']
             name = h['Name'].decode('utf8')
-            _id = h['Channel']
-            event_channels.append((name, _id, 'event'))
+            event_channels.append((name, chan_id, 'event'))
         event_channels = np.array(event_channels, dtype=_event_channel_dtype)
 
-        # fille into header dict
-        self.header = {}
-        self.header['nb_block'] = 1
-        self.header['nb_segment'] = [1]
-        self.header['signal_streams'] = signal_streams
-        self.header['signal_channels'] = sig_channels
-        self.header['spike_channels'] = spike_channels
-        self.header['event_channels'] = event_channels
+        # fill into header dict
+        self.header = {
+            "nb_block": 1,
+            "nb_segment": [1],
+            "signal_streams": signal_streams,
+            "signal_channels": sig_channels,
+            "spike_channels": spike_channels,
+            "event_channels": event_channels,
+        }
 
         # Annotations
         self._generate_minimal_annotations()
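
The block scan above cannot know the number of data blocks in advance, because each block's length is read from its own header; that is why the added progress bar is sized in bytes (total=len(data)) and advanced by each block's length instead of counting iterations. Below is a minimal standalone sketch of the same pattern against an invented block format (a 2-byte native-endian total size followed by a payload), not the real .plx layout:

import numpy as np
from tqdm import tqdm

# Invented block format for the demo: each block is a 2-byte total size
# (header included, native byte order) followed by the payload.
rng = np.random.default_rng(0)
blocks = []
for _ in range(1000):
    payload_size = int(rng.integers(5, 100)) * 2   # even sizes keep headers 2-byte aligned
    block = np.zeros(payload_size + 2, dtype=np.uint8)
    block[:2] = np.frombuffer(np.uint16(payload_size + 2).tobytes(), dtype=np.uint8)
    blocks.append(block)
data = np.concatenate(blocks)

# Progress is tracked in bytes: the bar's total is the buffer size and each
# iteration advances it by the size of the block just consumed.
pos = 0
progress_bar = tqdm(total=len(data), initial=0, desc="Scanning blocks", leave=True)
while pos < data.size:
    length = int(data[pos:pos + 2].view(np.uint16)[0])
    pos += length
    progress_bar.update(length)
progress_bar.close()

Storing self.progress_bar = HAVE_TQDM and progress_bar once in __init__, as the commit does, keeps tqdm an optional dependency: every loop site reduces to a single boolean check, and the flag is simply forced to False when the import fails.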

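From the caller's side, the new keyword is the only visible change. A hedged usage sketch follows; the file path is a placeholder, while parse_header() and the header dict are the standard BaseRawIO interface:

from neo.rawio import PlexonRawIO

# Placeholder path; substitute a real Plexon .plx recording.
reader = PlexonRawIO(filename='recording.plx', progress_bar=True)

# Runs the block scan shown in the diff; with tqdm installed this displays
# "Parsing data blocks" and the per-stage bars, otherwise the flag is
# silently ignored because __init__ stores HAVE_TQDM and progress_bar.
reader.parse_header()

print(reader.header['signal_channels'])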