Skip to content

Commit 22321d3

Browse files
committed
Make tqdm dependency optional for PlexonRawIO and add a option to show or not the progressbar
1 parent 8c747f4 commit 22321d3

File tree

2 files changed

+44
-20
lines changed

2 files changed

+44
-20
lines changed

neo/rawio/plexonrawio.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from collections import OrderedDict
2626

2727
import numpy as np
28-
from tqdm import tqdm, trange
28+
try:
29+
from tqdm import tqdm, trange
30+
HAVE_TQDM = True
31+
except:
32+
HAVE_TQDM = False
2933

3034
from .baserawio import (
3135
BaseRawIO,
@@ -40,9 +44,10 @@ class PlexonRawIO(BaseRawIO):
4044
extensions = ['plx']
4145
rawmode = 'one-file'
4246

43-
def __init__(self, filename=''):
47+
def __init__(self, filename='', progress_bar=True):
4448
BaseRawIO.__init__(self)
4549
self.filename = filename
50+
self.progress_bar = HAVE_TQDM and progress_bar
4651

4752
def _source_name(self):
4853
return self.filename
@@ -99,7 +104,8 @@ def _parse_header(self):
99104
pos = offset4
100105

101106
# Create a tqdm object with a total of len(data) and an initial value of 0 for offset
102-
progress_bar = tqdm(total=len(data), initial=0, desc="Parsing data blocks", leave=True)
107+
if self.progress_bar :
108+
progress_bar = tqdm(total=len(data), initial=0, desc="Parsing data blocks", leave=True)
103109

104110
while pos < data.size:
105111
bl_header = data[pos:pos + 16].view(DataBlockHeader)[0]
@@ -110,9 +116,11 @@ def _parse_header(self):
110116
pos += length
111117

112118
# Update tqdm with the number of bytes processed in this iteration
113-
progress_bar.update(length)
119+
if self.progress_bar :
120+
progress_bar.update(length)
114121

115-
progress_bar.close()
122+
if self.progress_bar :
123+
progress_bar.close()
116124

117125
self._last_timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \
118126
2 ** 32 + bl_header['TimeStamp']
@@ -129,13 +137,21 @@ def _parse_header(self):
129137
# Signals
130138
5: np.dtype(dt_base + [('cumsum', 'int64'), ]),
131139
}
132-
for bl_type in tqdm(block_pos, desc="Finalizing data blocks", leave=True):
140+
if self.progress_bar :
141+
bl_loop = tqdm(block_pos, desc="Finalizing data blocks", leave=True)
142+
else:
143+
bl_loop = block_pos
144+
for bl_type in bl_loop:
133145
self._data_blocks[bl_type] = {}
134-
for chan_id in tqdm(
135-
block_pos[bl_type],
136-
desc="Finalizing data blocks for type %d" % bl_type,
137-
leave=True,
138-
):
146+
if self.progress_bar :
147+
chan_loop = tqdm(
148+
block_pos[bl_type],
149+
desc="Finalizing data blocks for type %d" % bl_type,
150+
leave=True,
151+
)
152+
else:
153+
chan_loop = block_pos[bl_type]
154+
for chan_id in chan_loop:
139155
positions = block_pos[bl_type][chan_id]
140156
dt = dtype_by_bltype[bl_type]
141157
data_block = np.empty((len(positions)), dtype=dt)
@@ -171,7 +187,11 @@ def _parse_header(self):
171187
# signals channels
172188
sig_channels = []
173189
all_sig_length = []
174-
for chan_index in trange(nb_sig_chan, desc="Parsing signal channels", leave=True):
190+
if self.progress_bar:
191+
chan_loop = trange(nb_sig_chan, desc="Parsing signal channels", leave=True)
192+
else:
193+
chan_loop = range(nb_sig_chan)
194+
for chan_index in chan_loop:
175195
h = slowChannelHeaders[chan_index]
176196
name = h['Name'].decode('utf8')
177197
chan_id = h['Channel']
@@ -232,11 +252,16 @@ def _parse_header(self):
232252

233253
# Spikes channels
234254
spike_channels = []
235-
for unit_index, (chan_id, unit_id) in tqdm(
236-
enumerate(self.internal_unit_ids),
237-
desc="Parsing spike channels",
238-
leave=True,
239-
):
255+
if self.progress_bar:
256+
unit_loop = tqdm(
257+
enumerate(self.internal_unit_ids),
258+
desc="Parsing spike channels",
259+
leave=True,
260+
)
261+
else:
262+
unit_loop = enumerate(self.internal_unit_ids)
263+
264+
for unit_index, (chan_id, unit_id) in unit_loop:
240265
c = np.nonzero(dspChannelHeaders['Channel'] == chan_id)[0][0]
241266
h = dspChannelHeaders[c]
242267

@@ -433,13 +458,13 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index)
433458
return event_times
434459

435460

436-
def read_as_dict(fid, dtype, offset: int = 0):
461+
def read_as_dict(fid, dtype, offset=None):
437462
"""
438463
Given a file descriptor
439464
and a numpy.dtype of the binary struct return a dict.
440465
Make conversion for strings.
441466
"""
442-
if offset:
467+
if offset is not None:
443468
fid.seek(offset)
444469
dt = np.dtype(dtype)
445470
h = np.frombuffer(fid.read(dt.itemsize), dt)[0]

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ nwb = ["pynwb"]
9898
maxwell = ["h5py"]
9999
biocam = ["h5py"]
100100
med = ["dhn_med_py>=1.0.0"]
101-
plexon = ["tqdm"]
102101
plexon2 = ["zugbruecke>=0.2; sys_platform!='win32'", "wenv; sys_platform!='win32'"]
103102

104103
all = [

0 commit comments

Comments
 (0)