Skip to content

Commit 588dc46

Browse files
committed
wip spikegadgetsio
1 parent 3e4178a commit 588dc46

File tree

2 files changed

+54
-44
lines changed

2 files changed

+54
-44
lines changed

neo/rawio/spikegadgetsrawio.py

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
99
Note :
1010
* this file format have multiple version. news version include the gain for scaling.
11-
The actual implementation do not contain this feature because we don't have
12-
files to test this. So now the gain is "hardcoded"
11+
The actual implementation do not contain this feature because we don't have
12+
files to test this. So now the gain is "hardcoded" to 1. and so units
13+
is not handled correctly.
1314
1415
The file ".rec" have :
1516
* a fist part in text with xml informations
@@ -18,22 +19,21 @@
1819
Author: Samuel Garcia
1920
"""
2021
from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
21-
_spike_channel_dtype, _event_channel_dtype)
22+
_spike_channel_dtype, _event_channel_dtype)
2223

2324
import numpy as np
2425

2526
from xml.etree import ElementTree
2627

28+
2729
class SpikeGadgetsRawIO(BaseRawIO):
2830
extensions = ['rec']
2931
rawmode = 'one-file'
3032

3133
def __init__(self, filename='', selected_streams=None):
3234
"""
33-
3435
filename: str
3536
filename ".rec"
36-
3737
selected_streams: None, list, str
3838
sublist of streams to load/expose to API
3939
uselfull for spikeextractor when one stream isneed.
@@ -48,7 +48,6 @@ def _source_name(self):
4848
return self.filename
4949

5050
def _parse_header(self):
51-
5251
# parse file until "</Configuration>"
5352
header_size = None
5453
with open(self.filename, mode='rb') as f:
@@ -60,19 +59,19 @@ def _parse_header(self):
6059

6160
if header_size is None:
6261
ValueError("SpikeGadgets : the xml header do not contain </Configuration>")
63-
62+
6463
f.seek(0)
6564
header_txt = f.read(header_size).decode('utf8')
66-
65+
6766
# explore xml header
6867
root = ElementTree.fromstring(header_txt)
6968
gconf = sr = root.find('GlobalConfiguration')
7069
hconf = root.find('HardwareConfiguration')
7170
sconf = root.find('SpikeConfiguration')
72-
71+
7372
self._sampling_rate = float(hconf.attrib['samplingRate'])
7473
num_ephy_channels = int(hconf .attrib['numChannels'])
75-
74+
7675
# explore sub stream and count packet size
7776
# first bytes is 0x55
7877
packet_size = 1
@@ -82,83 +81,85 @@ def _parse_header(self):
8281
num_bytes = int(device.attrib['numBytes'])
8382
stream_bytes[stream_id] = packet_size
8483
packet_size += num_bytes
85-
84+
8685
# timesteamps 4 uint32
8786
self._timestamp_byte = packet_size
8887
packet_size += 4
89-
88+
9089
packet_size += 2 * num_ephy_channels
9190

9291
# read the binary part lazily
9392
raw_memmap = np.memmap(self.filename, mode='r', offset=header_size, dtype='<u1')
9493

9594
num_packet = raw_memmap.size // packet_size
96-
raw_memmap = raw_memmap[:num_packet*packet_size]
95+
raw_memmap = raw_memmap[:num_packet * packet_size]
9796
self._raw_memmap = raw_memmap.reshape(-1, packet_size)
9897

9998
# create signal channels
10099
stream_ids = []
101100
signal_streams = []
102101
signal_channels = []
103-
102+
104103
# walk in xml device and keep only "analog" one
105-
self._mask_channels_bytes = {}
104+
self._mask_channels_bytes = {}
106105
for device in hconf:
107106
stream_id = device.attrib['name']
108107
for channel in device:
109108

110109
if 'interleavedDataIDByte' in channel.attrib:
111-
# TODO deal with "headstageSensor" wich have interleaved
110+
# TODO LATER: deal with "headstageSensor" wich have interleaved
112111
continue
113112

114113
if channel.attrib['dataType'] == 'analog':
115-
114+
116115
if stream_id not in stream_ids:
117116
stream_ids.append(stream_id)
118117
stream_name = stream_id
119118
signal_streams.append((stream_name, stream_id))
120119
self._mask_channels_bytes[stream_id] = []
121-
120+
122121
name = channel.attrib['id']
123122
chan_id = channel.attrib['id']
124123
dtype = 'int16'
125-
units = 'uV' # TODO check where is the info
126-
gain = 1. # TODO check where is the info
124+
# TODO LATER : handle gain correctly according the file version
125+
units = ''
126+
gain = 1.
127127
offset = 0.
128128
signal_channels.append((name, chan_id, self._sampling_rate, 'int16',
129129
units, gain, offset, stream_id))
130-
130+
131131
num_bytes = stream_bytes[stream_id] + int(channel.attrib['startByte'])
132132
chan_mask = np.zeros(packet_size, dtype='bool')
133133
chan_mask[num_bytes] = True
134-
chan_mask[num_bytes+1] = True
134+
chan_mask[num_bytes + 1] = True
135135
self._mask_channels_bytes[stream_id].append(chan_mask)
136-
136+
137137
if num_ephy_channels > 0:
138138
stream_id = 'trodes'
139139
stream_name = stream_id
140140
signal_streams.append((stream_name, stream_id))
141141
self._mask_channels_bytes[stream_id] = []
142-
142+
143143
chan_ind = 0
144144
for trode in sconf:
145-
for schan in trode:
145+
for schan in trode:
146146
name = 'trode' + trode.attrib['id'] + 'chan' + schan.attrib['hwChan']
147147
chan_id = schan.attrib['hwChan']
148-
units = 'uV' # TODO check where is the info
149-
gain = 1. # TODO check where is the info
148+
# TODO LATER : handle gain correctly according the file version
149+
units = ''
150+
gain = 1.
150151
offset = 0.
151152
signal_channels.append((name, chan_id, self._sampling_rate, 'int16',
152153
units, gain, offset, stream_id))
153-
154+
154155
chan_mask = np.zeros(packet_size, dtype='bool')
155156
num_bytes = packet_size - 2 * num_ephy_channels + 2 * chan_ind
156157
chan_mask[num_bytes] = True
157-
chan_mask[num_bytes+1] = True
158+
chan_mask[num_bytes + 1] = True
158159
self._mask_channels_bytes[stream_id].append(chan_mask)
159-
160+
160161
chan_ind += 1
161-
162+
162163
# make mask as array (used in _get_analogsignal_chunk(...))
163164
self._mask_streams = {}
164165
for stream_id, l in self._mask_channels_bytes.items():
@@ -168,20 +169,19 @@ def _parse_header(self):
168169

169170
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
170171
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
171-
172-
172+
173173
# remove some stream if no wanted
174174
if self.selected_streams is not None:
175175
if isinstance(self.selected_streams, str):
176176
self.selected_streams = [self.selected_streams]
177177
assert isinstance(self.selected_streams, list)
178-
178+
179179
keep = np.in1d(signal_streams['id'], self.selected_streams)
180180
signal_streams = signal_streams[keep]
181181

182182
keep = np.in1d(signal_channels['stream_id'], self.selected_streams)
183183
signal_channels = signal_channels[keep]
184-
184+
185185
# No events channels
186186
event_channels = []
187187
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
@@ -202,7 +202,7 @@ def _parse_header(self):
202202
self._generate_minimal_annotations()
203203
# info from GlobalConfiguration in xml are copied to block and seg annotations
204204
bl_ann = self.raw_annotations['blocks'][0]
205-
seg_ann = self.raw_annotations['blocks'][0]['segments'][0]
205+
seg_ann = self.raw_annotations['blocks'][0]['segments'][0]
206206
for ann in (bl_ann, seg_ann):
207207
ann.update(gconf.attrib)
208208

@@ -221,32 +221,43 @@ def _get_signal_size(self, block_index, seg_index, stream_index):
221221
def _get_signal_t_start(self, block_index, seg_index, stream_index):
222222
return 0.
223223

224-
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes):
224+
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index,
225+
channel_indexes):
225226
stream_id = self.header['signal_streams'][stream_index]['id']
226227

227228
raw_unit8 = self._raw_memmap[i_start:i_stop]
228-
229+
229230
num_chan = len(self._mask_channels_bytes[stream_id])
231+
re_order = None
230232
if channel_indexes is None:
231233
# no loop : entire stream mask
232234
stream_mask = self._mask_streams[stream_id]
233235
else:
234-
# acculate mask
235-
if isinstance(channel_indexes, slice):
236+
# accumulate mask
237+
if isinstance(channel_indexes, slice):
236238
chan_inds = np.arange(num_chan)[channel_indexes]
237239
else:
238240
chan_inds = channel_indexes
241+
242+
if np.any(np.diff(channel_indexes) < 0):
243+
# handle channel are not ordered
244+
sorted_channel_indexes = np.sort(channel_indexes)
245+
re_order = np.array([list(sorted_channel_indexes).index(ch)
246+
for ch in channel_indexes])
247+
239248
stream_mask = np.zeros(raw_unit8.shape[1], dtype='bool')
240249
for chan_ind in chan_inds:
241250
chan_mask = self._mask_channels_bytes[stream_id][chan_ind]
242251
stream_mask |= chan_mask
243-
# TODO : make a fix when "channel_indexes" arein wring order.
244252

245253
# this do a copy from memmap to memory
246254
raw_unit8_mask = raw_unit8[:, stream_mask]
247255
shape = raw_unit8_mask.shape
248256
shape = (shape[0], shape[1] // 2)
249257
# reshape the and re type by view
250258
raw_unit16 = raw_unit8_mask.flatten().view('int16').reshape(shape)
251-
259+
260+
if re_order is not None:
261+
raw_unit16 = raw_unit16[:, re_order]
262+
252263
return raw_unit16

neo/test/iotest/test_spikegadgetsio.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
simport unittest
1+
import unittest
22

33
from neo.io import SpikeGadgetsIO
44
from neo.test.iotest.common_io_test import BaseTestIO
@@ -13,6 +13,5 @@ class TestSpikeGadgetsIO(BaseTestIO, unittest.TestCase, ):
1313
]
1414

1515

16-
1716
if __name__ == "__main__":
1817
unittest.main()

0 commit comments

Comments
 (0)