Skip to content

Commit ded9a66

Browse files
committed
* optimize imports
* add tqdm to longer steps * a bit of black-like formatting * add tqdm as a requirement for plexon
1 parent a36df30 commit ded9a66

File tree

2 files changed

+64
-45
lines changed

2 files changed

+64
-45
lines changed

neo/rawio/plexonrawio.py

Lines changed: 63 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
Author: Samuel Garcia
2222
2323
"""
24-
25-
from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
26-
_spike_channel_dtype, _event_channel_dtype)
24+
import datetime
25+
from collections import OrderedDict
2726

2827
import numpy as np
29-
from collections import OrderedDict
30-
import datetime
28+
from tqdm import tqdm, trange
29+
30+
from .baserawio import (
31+
BaseRawIO, _signal_channel_dtype, _signal_stream_dtype, _spike_channel_dtype, _event_channel_dtype
32+
)
3133

3234

3335
class PlexonRawIO(BaseRawIO):
@@ -45,43 +47,52 @@ def _parse_header(self):
4547

4648
# global header
4749
with open(self.filename, 'rb') as fid:
48-
offset0 = 0
49-
global_header = read_as_dict(fid, GlobalHeader, offset=offset0)
50+
global_header = read_as_dict(fid, GlobalHeader)
5051

51-
rec_datetime = datetime.datetime(global_header['Year'],
52-
global_header['Month'],
53-
global_header['Day'],
54-
global_header['Hour'],
55-
global_header['Minute'],
56-
global_header['Second'])
52+
rec_datetime = datetime.datetime(
53+
global_header['Year'],
54+
global_header['Month'],
55+
global_header['Day'],
56+
global_header['Hour'],
57+
global_header['Minute'],
58+
global_header['Second'],
59+
)
5760

5861
# dsp channels header = spikes and waveforms
5962
nb_unit_chan = global_header['NumDSPChannels']
6063
offset1 = np.dtype(GlobalHeader).itemsize
61-
dspChannelHeaders = np.memmap(self.filename, dtype=DspChannelHeader, mode='r',
62-
offset=offset1, shape=(nb_unit_chan,))
64+
dspChannelHeaders = np.memmap(
65+
self.filename, dtype=DspChannelHeader, mode='r', offset=offset1, shape=(nb_unit_chan,)
66+
)
6367

6468
# event channel header
6569
nb_event_chan = global_header['NumEventChannels']
6670
offset2 = offset1 + np.dtype(DspChannelHeader).itemsize * nb_unit_chan
67-
eventHeaders = np.memmap(self.filename, dtype=EventChannelHeader, mode='r',
68-
offset=offset2, shape=(nb_event_chan,))
71+
eventHeaders = np.memmap(
72+
self.filename, dtype=EventChannelHeader, mode='r', offset=offset2, shape=(nb_event_chan,)
73+
)
6974

7075
# slow channel header = signal
7176
nb_sig_chan = global_header['NumSlowChannels']
7277
offset3 = offset2 + np.dtype(EventChannelHeader).itemsize * nb_event_chan
73-
slowChannelHeaders = np.memmap(self.filename, dtype=SlowChannelHeader, mode='r',
74-
offset=offset3, shape=(nb_sig_chan,))
78+
slowChannelHeaders = np.memmap(
79+
self.filename, dtype=SlowChannelHeader, mode='r', offset=offset3, shape=(nb_sig_chan,)
80+
)
7581

7682
offset4 = offset3 + np.dtype(SlowChannelHeader).itemsize * nb_sig_chan
7783

7884
# locate data blocks and group them by type and channel
79-
block_pos = {1: {c: [] for c in dspChannelHeaders['Channel']},
80-
4: {c: [] for c in eventHeaders['Channel']},
81-
5: {c: [] for c in slowChannelHeaders['Channel']},
82-
}
85+
block_pos = {
86+
1: {c: [] for c in dspChannelHeaders['Channel']},
87+
4: {c: [] for c in eventHeaders['Channel']},
88+
5: {c: [] for c in slowChannelHeaders['Channel']},
89+
}
8390
data = self._memmap = np.memmap(self.filename, dtype='u1', offset=0, mode='r')
8491
pos = offset4
92+
93+
# Create a tqdm object with a total of len(data) and an initial value of 0 for offset
94+
progress_bar = tqdm(total=len(data), initial=0, desc="Parsing data blocks", leave=True)
95+
8596
while pos < data.size:
8697
bl_header = data[pos:pos + 16].view(DataBlockHeader)[0]
8798
length = bl_header['NumberOfWaveforms'] * bl_header['NumberOfWordsInWaveform'] * 2 + 16
@@ -90,6 +101,11 @@ def _parse_header(self):
90101
block_pos[bl_type][chan_id].append(pos)
91102
pos += length
92103

104+
# Update tqdm with the number of bytes processed in this iteration
105+
progress_bar.update(length)
106+
107+
progress_bar.close()
108+
93109
self._last_timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \
94110
2 ** 32 + bl_header['TimeStamp']
95111

@@ -105,9 +121,9 @@ def _parse_header(self):
105121
# Signals
106122
5: np.dtype(dt_base + [('cumsum', 'int64'), ]),
107123
}
108-
for bl_type in block_pos:
124+
for bl_type in tqdm(block_pos, desc="Finalizing data blocks", leave=True):
109125
self._data_blocks[bl_type] = {}
110-
for chan_id in block_pos[bl_type]:
126+
for chan_id in tqdm(block_pos[bl_type], desc="Finalizing data blocks for type %d" % bl_type, leave=True):
111127
positions = block_pos[bl_type][chan_id]
112128
dt = dtype_by_bltype[bl_type]
113129
data_block = np.empty((len(positions)), dtype=dt)
@@ -132,7 +148,7 @@ def _parse_header(self):
132148
data_block['label'][index] = bl_header['Unit']
133149
elif bl_type == 5: # Signals
134150
if data_block.size > 0:
135-
# cumulative some of sample index for fast access to chunks
151+
# cumulative sum of sample index for fast access to chunks
136152
if index == 0:
137153
data_block['cumsum'][index] = 0
138154
else:
@@ -143,7 +159,7 @@ def _parse_header(self):
143159
# signals channels
144160
sig_channels = []
145161
all_sig_length = []
146-
for chan_index in range(nb_sig_chan):
162+
for chan_index in trange(nb_sig_chan, desc="Parsing signal channels", leave=True):
147163
h = slowChannelHeaders[chan_index]
148164
name = h['Name'].decode('utf8')
149165
chan_id = h['Channel']
@@ -164,8 +180,9 @@ def _parse_header(self):
164180
h['Gain'] * h['PreampGain'])
165181
offset = 0.
166182
stream_id = '0'
167-
sig_channels.append((name, str(chan_id), sampling_rate, sig_dtype,
168-
units, gain, offset, stream_id))
183+
sig_channels.append(
184+
(name, str(chan_id), sampling_rate, sig_dtype, units, gain, offset, stream_id)
185+
)
169186

170187
sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
171188

@@ -203,7 +220,7 @@ def _parse_header(self):
203220

204221
# Spikes channels
205222
spike_channels = []
206-
for unit_index, (chan_id, unit_id) in enumerate(self.internal_unit_ids):
223+
for unit_index, (chan_id, unit_id) in tqdm(enumerate(self.internal_unit_ids), desc="Parsing spike channels", leave=True):
207224
c = np.nonzero(dspChannelHeaders['Channel'] == chan_id)[0][0]
208225
h = dspChannelHeaders[c]
209226

@@ -223,28 +240,29 @@ def _parse_header(self):
223240
wf_offset = 0.
224241
wf_left_sweep = -1 # DONT KNOWN
225242
wf_sampling_rate = global_header['WaveformFreq']
226-
spike_channels.append((name, _id, wf_units, wf_gain, wf_offset,
227-
wf_left_sweep, wf_sampling_rate))
243+
spike_channels.append(
244+
(name, _id, wf_units, wf_gain, wf_offset, wf_left_sweep, wf_sampling_rate)
245+
)
228246
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
229247

230248
# Event channels
231249
event_channels = []
232-
for chan_index in range(nb_event_chan):
250+
for chan_index in trange(nb_event_chan, desc="Parsing event channels", leave=True):
233251
h = eventHeaders[chan_index]
234252
chan_id = h['Channel']
235253
name = h['Name'].decode('utf8')
236-
_id = h['Channel']
237-
event_channels.append((name, _id, 'event'))
254+
event_channels.append((name, chan_id, 'event'))
238255
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
239256

240-
# fille into header dict
241-
self.header = {}
242-
self.header['nb_block'] = 1
243-
self.header['nb_segment'] = [1]
244-
self.header['signal_streams'] = signal_streams
245-
self.header['signal_channels'] = sig_channels
246-
self.header['spike_channels'] = spike_channels
247-
self.header['event_channels'] = event_channels
257+
# fill into header dict
258+
self.header = {
259+
"nb_block": 1,
260+
"nb_segment": [1],
261+
"signal_streams": signal_streams,
262+
"signal_channels": sig_channels,
263+
"spike_channels": spike_channels,
264+
"event_channels": event_channels,
265+
}
248266

249267
# Annotations
250268
self._generate_minimal_annotations()
@@ -399,13 +417,13 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index)
399417
return event_times
400418

401419

402-
def read_as_dict(fid, dtype, offset=None):
420+
def read_as_dict(fid, dtype, offset: int = 0):
403421
"""
404422
Given a file descriptor
405423
and a numpy.dtype of the binary struct return a dict.
406424
Make conversion for strings.
407425
"""
408-
if offset is not None:
426+
if offset:
409427
fid.seek(offset)
410428
dt = np.dtype(dtype)
411429
h = np.frombuffer(fid.read(dt.itemsize), dt)[0]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ nwb = ["pynwb"]
9898
maxwell = ["h5py"]
9999
biocam = ["h5py"]
100100
med = ["dhn_med_py>=1.0.0"]
101+
plexon = ["tqdm"]
101102
plexon2 = ["zugbruecke>=0.2; sys_platform!='win32'", "wenv; sys_platform!='win32'"]
102103

103104
all = [

0 commit comments

Comments
 (0)