2121Author: Samuel Garcia
2222
2323"""
24-
25- from .baserawio import (BaseRawIO , _signal_channel_dtype , _signal_stream_dtype ,
26- _spike_channel_dtype , _event_channel_dtype )
24+ import datetime
25+ from collections import OrderedDict
2726
2827import numpy as np
29- from collections import OrderedDict
30- import datetime
28+ try :
29+ from tqdm import tqdm , trange
30+ HAVE_TQDM = True
31+ except :
32+ HAVE_TQDM = False
33+
34+ from .baserawio import (
35+ BaseRawIO ,
36+ _signal_channel_dtype ,
37+ _signal_stream_dtype ,
38+ _spike_channel_dtype ,
39+ _event_channel_dtype ,
40+ )
3141
3242
3343class PlexonRawIO (BaseRawIO ):
3444 extensions = ['plx' ]
3545 rawmode = 'one-file'
3646
37- def __init__ (self , filename = '' ):
47+ def __init__ (self , filename = '' , progress_bar = True ):
48+ """
49+
50+ Parameters
51+ ----------
52+ filename: str
53+ The filename.
54+ progress_bar: bool, default True
55+ Display progress bar using tqdm (if installed) when parsing the file.
56+
57+ """
3858 BaseRawIO .__init__ (self )
3959 self .filename = filename
60+ self .progress_bar = HAVE_TQDM and progress_bar
4061
4162 def _source_name (self ):
4263 return self .filename
@@ -45,43 +66,57 @@ def _parse_header(self):
4566
4667 # global header
4768 with open (self .filename , 'rb' ) as fid :
48- offset0 = 0
49- global_header = read_as_dict (fid , GlobalHeader , offset = offset0 )
69+ global_header = read_as_dict (fid , GlobalHeader )
5070
51- rec_datetime = datetime .datetime (global_header ['Year' ],
52- global_header ['Month' ],
53- global_header ['Day' ],
54- global_header ['Hour' ],
55- global_header ['Minute' ],
56- global_header ['Second' ])
71+ rec_datetime = datetime .datetime (
72+ global_header ['Year' ],
73+ global_header ['Month' ],
74+ global_header ['Day' ],
75+ global_header ['Hour' ],
76+ global_header ['Minute' ],
77+ global_header ['Second' ],
78+ )
5779
5880 # dsp channels header = spikes and waveforms
5981 nb_unit_chan = global_header ['NumDSPChannels' ]
6082 offset1 = np .dtype (GlobalHeader ).itemsize
61- dspChannelHeaders = np .memmap (self .filename , dtype = DspChannelHeader , mode = 'r' ,
62- offset = offset1 , shape = (nb_unit_chan ,))
83+ dspChannelHeaders = np .memmap (
84+ self .filename , dtype = DspChannelHeader , mode = 'r' , offset = offset1 , shape = (nb_unit_chan ,)
85+ )
6386
6487 # event channel header
6588 nb_event_chan = global_header ['NumEventChannels' ]
6689 offset2 = offset1 + np .dtype (DspChannelHeader ).itemsize * nb_unit_chan
67- eventHeaders = np .memmap (self .filename , dtype = EventChannelHeader , mode = 'r' ,
68- offset = offset2 , shape = (nb_event_chan ,))
90+ eventHeaders = np .memmap (
91+ self .filename ,
92+ dtype = EventChannelHeader ,
93+ mode = 'r' ,
94+ offset = offset2 ,
95+ shape = (nb_event_chan ,),
96+ )
6997
7098 # slow channel header = signal
7199 nb_sig_chan = global_header ['NumSlowChannels' ]
72100 offset3 = offset2 + np .dtype (EventChannelHeader ).itemsize * nb_event_chan
73- slowChannelHeaders = np .memmap (self .filename , dtype = SlowChannelHeader , mode = 'r' ,
74- offset = offset3 , shape = (nb_sig_chan ,))
101+ slowChannelHeaders = np .memmap (
102+ self .filename , dtype = SlowChannelHeader , mode = 'r' , offset = offset3 , shape = (nb_sig_chan ,)
103+ )
75104
76105 offset4 = offset3 + np .dtype (SlowChannelHeader ).itemsize * nb_sig_chan
77106
78107 # locate data blocks and group them by type and channel
79- block_pos = {1 : {c : [] for c in dspChannelHeaders ['Channel' ]},
80- 4 : {c : [] for c in eventHeaders ['Channel' ]},
81- 5 : {c : [] for c in slowChannelHeaders ['Channel' ]},
82- }
108+ block_pos = {
109+ 1 : {c : [] for c in dspChannelHeaders ['Channel' ]},
110+ 4 : {c : [] for c in eventHeaders ['Channel' ]},
111+ 5 : {c : [] for c in slowChannelHeaders ['Channel' ]},
112+ }
83113 data = self ._memmap = np .memmap (self .filename , dtype = 'u1' , offset = 0 , mode = 'r' )
84114 pos = offset4
115+
116+ # Create a tqdm object with a total of len(data) and an initial value of 0 for offset
117+ if self .progress_bar :
118+ progress_bar = tqdm (total = len (data ), initial = 0 , desc = "Parsing data blocks" , leave = True )
119+
85120 while pos < data .size :
86121 bl_header = data [pos :pos + 16 ].view (DataBlockHeader )[0 ]
87122 length = bl_header ['NumberOfWaveforms' ] * bl_header ['NumberOfWordsInWaveform' ] * 2 + 16
@@ -90,6 +125,13 @@ def _parse_header(self):
90125 block_pos [bl_type ][chan_id ].append (pos )
91126 pos += length
92127
128+ # Update tqdm with the number of bytes processed in this iteration
129+ if self .progress_bar :
130+ progress_bar .update (length )
131+
132+ if self .progress_bar :
133+ progress_bar .close ()
134+
93135 self ._last_timestamps = bl_header ['UpperByteOf5ByteTimestamp' ] * \
94136 2 ** 32 + bl_header ['TimeStamp' ]
95137
@@ -105,9 +147,21 @@ def _parse_header(self):
105147 # Signals
106148 5 : np .dtype (dt_base + [('cumsum' , 'int64' ), ]),
107149 }
108- for bl_type in block_pos :
150+ if self .progress_bar :
151+ bl_loop = tqdm (block_pos , desc = "Finalizing data blocks" , leave = True )
152+ else :
153+ bl_loop = block_pos
154+ for bl_type in bl_loop :
109155 self ._data_blocks [bl_type ] = {}
110- for chan_id in block_pos [bl_type ]:
156+ if self .progress_bar :
157+ chan_loop = tqdm (
158+ block_pos [bl_type ],
159+ desc = "Finalizing data blocks for type %d" % bl_type ,
160+ leave = True ,
161+ )
162+ else :
163+ chan_loop = block_pos [bl_type ]
164+ for chan_id in chan_loop :
111165 positions = block_pos [bl_type ][chan_id ]
112166 dt = dtype_by_bltype [bl_type ]
113167 data_block = np .empty ((len (positions )), dtype = dt )
@@ -132,7 +186,7 @@ def _parse_header(self):
132186 data_block ['label' ][index ] = bl_header ['Unit' ]
133187 elif bl_type == 5 : # Signals
134188 if data_block .size > 0 :
135- # cumulative some of sample index for fast access to chunks
189+ # cumulative sum of sample index for fast access to chunks
136190 if index == 0 :
137191 data_block ['cumsum' ][index ] = 0
138192 else :
@@ -143,7 +197,11 @@ def _parse_header(self):
143197 # signals channels
144198 sig_channels = []
145199 all_sig_length = []
146- for chan_index in range (nb_sig_chan ):
200+ if self .progress_bar :
201+ chan_loop = trange (nb_sig_chan , desc = "Parsing signal channels" , leave = True )
202+ else :
203+ chan_loop = range (nb_sig_chan )
204+ for chan_index in chan_loop :
147205 h = slowChannelHeaders [chan_index ]
148206 name = h ['Name' ].decode ('utf8' )
149207 chan_id = h ['Channel' ]
@@ -164,8 +222,9 @@ def _parse_header(self):
164222 h ['Gain' ] * h ['PreampGain' ])
165223 offset = 0.
166224 stream_id = '0'
167- sig_channels .append ((name , str (chan_id ), sampling_rate , sig_dtype ,
168- units , gain , offset , stream_id ))
225+ sig_channels .append (
226+ (name , str (chan_id ), sampling_rate , sig_dtype , units , gain , offset , stream_id )
227+ )
169228
170229 sig_channels = np .array (sig_channels , dtype = _signal_channel_dtype )
171230
@@ -203,7 +262,16 @@ def _parse_header(self):
203262
204263 # Spikes channels
205264 spike_channels = []
206- for unit_index , (chan_id , unit_id ) in enumerate (self .internal_unit_ids ):
265+ if self .progress_bar :
266+ unit_loop = tqdm (
267+ enumerate (self .internal_unit_ids ),
268+ desc = "Parsing spike channels" ,
269+ leave = True ,
270+ )
271+ else :
272+ unit_loop = enumerate (self .internal_unit_ids )
273+
274+ for unit_index , (chan_id , unit_id ) in unit_loop :
207275 c = np .nonzero (dspChannelHeaders ['Channel' ] == chan_id )[0 ][0 ]
208276 h = dspChannelHeaders [c ]
209277
@@ -223,28 +291,33 @@ def _parse_header(self):
223291 wf_offset = 0.
224292 wf_left_sweep = - 1 # DONT KNOWN
225293 wf_sampling_rate = global_header ['WaveformFreq' ]
226- spike_channels .append ((name , _id , wf_units , wf_gain , wf_offset ,
227- wf_left_sweep , wf_sampling_rate ))
294+ spike_channels .append (
295+ (name , _id , wf_units , wf_gain , wf_offset , wf_left_sweep , wf_sampling_rate )
296+ )
228297 spike_channels = np .array (spike_channels , dtype = _spike_channel_dtype )
229298
230299 # Event channels
231300 event_channels = []
232- for chan_index in range (nb_event_chan ):
301+ if self .progress_bar :
302+ chan_loop = trange (nb_event_chan , desc = "Parsing event channels" , leave = True )
303+ else :
304+ chan_loop = range (nb_event )
305+ for chan_index in chan_loop :
233306 h = eventHeaders [chan_index ]
234307 chan_id = h ['Channel' ]
235308 name = h ['Name' ].decode ('utf8' )
236- _id = h ['Channel' ]
237- event_channels .append ((name , _id , 'event' ))
309+ event_channels .append ((name , chan_id , 'event' ))
238310 event_channels = np .array (event_channels , dtype = _event_channel_dtype )
239311
240- # fille into header dict
241- self .header = {}
242- self .header ['nb_block' ] = 1
243- self .header ['nb_segment' ] = [1 ]
244- self .header ['signal_streams' ] = signal_streams
245- self .header ['signal_channels' ] = sig_channels
246- self .header ['spike_channels' ] = spike_channels
247- self .header ['event_channels' ] = event_channels
312+ # fill into header dict
313+ self .header = {
314+ "nb_block" : 1 ,
315+ "nb_segment" : [1 ],
316+ "signal_streams" : signal_streams ,
317+ "signal_channels" : sig_channels ,
318+ "spike_channels" : spike_channels ,
319+ "event_channels" : event_channels ,
320+ }
248321
249322 # Annotations
250323 self ._generate_minimal_annotations ()
0 commit comments