88
99Note :
1010 * this file format have multiple version. news version include the gain for scaling.
11- The actual implementation do not contain this feature because we don't have
12- files to test this. So now the gain is "hardcoded"
11+ The actual implementation do not contain this feature because we don't have
12+ files to test this. So now the gain is "hardcoded" to 1. and so units
13+ is not handled correctly.
1314
1415The file ".rec" have :
1516 * a fist part in text with xml informations
1819Author: Samuel Garcia
1920"""
2021from .baserawio import (BaseRawIO , _signal_channel_dtype , _signal_stream_dtype ,
21- _spike_channel_dtype , _event_channel_dtype )
22+ _spike_channel_dtype , _event_channel_dtype )
2223
2324import numpy as np
2425
2526from xml .etree import ElementTree
2627
28+
2729class SpikeGadgetsRawIO (BaseRawIO ):
2830 extensions = ['rec' ]
2931 rawmode = 'one-file'
3032
3133 def __init__ (self , filename = '' , selected_streams = None ):
3234 """
33-
3435 filename: str
3536 filename ".rec"
36-
3737 selected_streams: None, list, str
3838 sublist of streams to load/expose to API
3939 uselfull for spikeextractor when one stream isneed.
@@ -48,7 +48,6 @@ def _source_name(self):
4848 return self .filename
4949
5050 def _parse_header (self ):
51-
5251 # parse file until "</Configuration>"
5352 header_size = None
5453 with open (self .filename , mode = 'rb' ) as f :
@@ -60,19 +59,19 @@ def _parse_header(self):
6059
6160 if header_size is None :
6261 ValueError ("SpikeGadgets : the xml header do not contain </Configuration>" )
63-
62+
6463 f .seek (0 )
6564 header_txt = f .read (header_size ).decode ('utf8' )
66-
65+
6766 # explore xml header
6867 root = ElementTree .fromstring (header_txt )
6968 gconf = sr = root .find ('GlobalConfiguration' )
7069 hconf = root .find ('HardwareConfiguration' )
7170 sconf = root .find ('SpikeConfiguration' )
72-
71+
7372 self ._sampling_rate = float (hconf .attrib ['samplingRate' ])
7473 num_ephy_channels = int (hconf .attrib ['numChannels' ])
75-
74+
7675 # explore sub stream and count packet size
7776 # first bytes is 0x55
7877 packet_size = 1
@@ -82,83 +81,85 @@ def _parse_header(self):
8281 num_bytes = int (device .attrib ['numBytes' ])
8382 stream_bytes [stream_id ] = packet_size
8483 packet_size += num_bytes
85-
84+
8685 # timesteamps 4 uint32
8786 self ._timestamp_byte = packet_size
8887 packet_size += 4
89-
88+
9089 packet_size += 2 * num_ephy_channels
9190
9291 # read the binary part lazily
9392 raw_memmap = np .memmap (self .filename , mode = 'r' , offset = header_size , dtype = '<u1' )
9493
9594 num_packet = raw_memmap .size // packet_size
96- raw_memmap = raw_memmap [:num_packet * packet_size ]
95+ raw_memmap = raw_memmap [:num_packet * packet_size ]
9796 self ._raw_memmap = raw_memmap .reshape (- 1 , packet_size )
9897
9998 # create signal channels
10099 stream_ids = []
101100 signal_streams = []
102101 signal_channels = []
103-
102+
104103 # walk in xml device and keep only "analog" one
105- self ._mask_channels_bytes = {}
104+ self ._mask_channels_bytes = {}
106105 for device in hconf :
107106 stream_id = device .attrib ['name' ]
108107 for channel in device :
109108
110109 if 'interleavedDataIDByte' in channel .attrib :
111- # TODO deal with "headstageSensor" wich have interleaved
110+ # TODO LATER: deal with "headstageSensor" wich have interleaved
112111 continue
113112
114113 if channel .attrib ['dataType' ] == 'analog' :
115-
114+
116115 if stream_id not in stream_ids :
117116 stream_ids .append (stream_id )
118117 stream_name = stream_id
119118 signal_streams .append ((stream_name , stream_id ))
120119 self ._mask_channels_bytes [stream_id ] = []
121-
120+
122121 name = channel .attrib ['id' ]
123122 chan_id = channel .attrib ['id' ]
124123 dtype = 'int16'
125- units = 'uV' # TODO check where is the info
126- gain = 1. # TODO check where is the info
124+ # TODO LATER : handle gain correctly according the file version
125+ units = ''
126+ gain = 1.
127127 offset = 0.
128128 signal_channels .append ((name , chan_id , self ._sampling_rate , 'int16' ,
129129 units , gain , offset , stream_id ))
130-
130+
131131 num_bytes = stream_bytes [stream_id ] + int (channel .attrib ['startByte' ])
132132 chan_mask = np .zeros (packet_size , dtype = 'bool' )
133133 chan_mask [num_bytes ] = True
134- chan_mask [num_bytes + 1 ] = True
134+ chan_mask [num_bytes + 1 ] = True
135135 self ._mask_channels_bytes [stream_id ].append (chan_mask )
136-
136+
137137 if num_ephy_channels > 0 :
138138 stream_id = 'trodes'
139139 stream_name = stream_id
140140 signal_streams .append ((stream_name , stream_id ))
141141 self ._mask_channels_bytes [stream_id ] = []
142-
142+
143143 chan_ind = 0
144144 for trode in sconf :
145- for schan in trode :
145+ for schan in trode :
146146 name = 'trode' + trode .attrib ['id' ] + 'chan' + schan .attrib ['hwChan' ]
147147 chan_id = schan .attrib ['hwChan' ]
148- units = 'uV' # TODO check where is the info
149- gain = 1. # TODO check where is the info
148+ # TODO LATER : handle gain correctly according the file version
149+ units = ''
150+ gain = 1.
150151 offset = 0.
151152 signal_channels .append ((name , chan_id , self ._sampling_rate , 'int16' ,
152153 units , gain , offset , stream_id ))
153-
154+
154155 chan_mask = np .zeros (packet_size , dtype = 'bool' )
155156 num_bytes = packet_size - 2 * num_ephy_channels + 2 * chan_ind
156157 chan_mask [num_bytes ] = True
157- chan_mask [num_bytes + 1 ] = True
158+ chan_mask [num_bytes + 1 ] = True
158159 self ._mask_channels_bytes [stream_id ].append (chan_mask )
159-
160+
160161 chan_ind += 1
161-
162+
162163 # make mask as array (used in _get_analogsignal_chunk(...))
163164 self ._mask_streams = {}
164165 for stream_id , l in self ._mask_channels_bytes .items ():
@@ -168,20 +169,19 @@ def _parse_header(self):
168169
169170 signal_streams = np .array (signal_streams , dtype = _signal_stream_dtype )
170171 signal_channels = np .array (signal_channels , dtype = _signal_channel_dtype )
171-
172-
172+
173173 # remove some stream if no wanted
174174 if self .selected_streams is not None :
175175 if isinstance (self .selected_streams , str ):
176176 self .selected_streams = [self .selected_streams ]
177177 assert isinstance (self .selected_streams , list )
178-
178+
179179 keep = np .in1d (signal_streams ['id' ], self .selected_streams )
180180 signal_streams = signal_streams [keep ]
181181
182182 keep = np .in1d (signal_channels ['stream_id' ], self .selected_streams )
183183 signal_channels = signal_channels [keep ]
184-
184+
185185 # No events channels
186186 event_channels = []
187187 event_channels = np .array (event_channels , dtype = _event_channel_dtype )
@@ -202,7 +202,7 @@ def _parse_header(self):
202202 self ._generate_minimal_annotations ()
203203 # info from GlobalConfiguration in xml are copied to block and seg annotations
204204 bl_ann = self .raw_annotations ['blocks' ][0 ]
205- seg_ann = self .raw_annotations ['blocks' ][0 ]['segments' ][0 ]
205+ seg_ann = self .raw_annotations ['blocks' ][0 ]['segments' ][0 ]
206206 for ann in (bl_ann , seg_ann ):
207207 ann .update (gconf .attrib )
208208
@@ -221,32 +221,43 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
221221 def _get_signal_t_start (self , block_index , seg_index , stream_index ):
222222 return 0.
223223
224- def _get_analogsignal_chunk (self , block_index , seg_index , i_start , i_stop , stream_index , channel_indexes ):
224+ def _get_analogsignal_chunk (self , block_index , seg_index , i_start , i_stop , stream_index ,
225+ channel_indexes ):
225226 stream_id = self .header ['signal_streams' ][stream_index ]['id' ]
226227
227228 raw_unit8 = self ._raw_memmap [i_start :i_stop ]
228-
229+
229230 num_chan = len (self ._mask_channels_bytes [stream_id ])
231+ re_order = None
230232 if channel_indexes is None :
231233 # no loop : entire stream mask
232234 stream_mask = self ._mask_streams [stream_id ]
233235 else :
234- # acculate mask
235- if isinstance (channel_indexes , slice ):
236+ # accumulate mask
237+ if isinstance (channel_indexes , slice ):
236238 chan_inds = np .arange (num_chan )[channel_indexes ]
237239 else :
238240 chan_inds = channel_indexes
241+
242+ if np .any (np .diff (channel_indexes ) < 0 ):
243+ # handle channel are not ordered
244+ sorted_channel_indexes = np .sort (channel_indexes )
245+ re_order = np .array ([list (sorted_channel_indexes ).index (ch )
246+ for ch in channel_indexes ])
247+
239248 stream_mask = np .zeros (raw_unit8 .shape [1 ], dtype = 'bool' )
240249 for chan_ind in chan_inds :
241250 chan_mask = self ._mask_channels_bytes [stream_id ][chan_ind ]
242251 stream_mask |= chan_mask
243- # TODO : make a fix when "channel_indexes" arein wring order.
244252
245253 # this do a copy from memmap to memory
246254 raw_unit8_mask = raw_unit8 [:, stream_mask ]
247255 shape = raw_unit8_mask .shape
248256 shape = (shape [0 ], shape [1 ] // 2 )
249257 # reshape the and re type by view
250258 raw_unit16 = raw_unit8_mask .flatten ().view ('int16' ).reshape (shape )
251-
259+
260+ if re_order is not None :
261+ raw_unit16 = raw_unit16 [:, re_order ]
262+
252263 return raw_unit16
0 commit comments