Skip to content

Commit 33729b5

Browse files
Merge pull request #954 from samuelgarcia/spikegadgets_implementation
Spikegadgets implementation
2 parents 7254820 + 45cbd44 commit 33729b5

File tree

6 files changed

+327
-0
lines changed

6 files changed

+327
-0
lines changed

neo/io/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
* :attr:`RawBinarySignalIO`
5353
* :attr:`RawMCSIO`
5454
* :attr:`Spike2IO`
55+
* :attr:`SpikeGadgetsIO`
5556
* :attr:`SpikeGLXIO`
5657
* :attr:`StimfitIO`
5758
* :attr:`TdtIO`
@@ -216,6 +217,10 @@
216217
217218
.. autoattribute:: extensions
218219
220+
.. autoclass:: SpikeGadgetsIO
221+
222+
.. autoattribute:: extensions
223+
219224
.. autoclass:: SpikeGLXIO
220225
221226
.. autoattribute:: extensions
@@ -300,6 +305,7 @@
300305
from neo.io.rawbinarysignalio import RawBinarySignalIO
301306
from neo.io.rawmcsio import RawMCSIO
302307
from neo.io.spike2io import Spike2IO
308+
from neo.io.spikegadgetsio import SpikeGadgetsIO
303309
from neo.io.spikeglxio import SpikeGLXIO
304310
from neo.io.stimfitio import StimfitIO
305311
from neo.io.tdtio import TdtIO
@@ -348,6 +354,7 @@
348354
RawBinarySignalIO,
349355
RawMCSIO,
350356
Spike2IO,
357+
SpikeGadgetsIO,
351358
SpikeGLXIO,
352359
StimfitIO,
353360
TdtIO,

neo/io/spikegadgetsio.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from neo.io.basefromrawio import BaseFromRaw
2+
from neo.rawio.spikegadgetsrawio import SpikeGadgetsRawIO
3+
4+
5+
class SpikeGadgetsIO(SpikeGadgetsRawIO, BaseFromRaw):
6+
__doc__ = SpikeGadgetsRawIO.__doc__
7+
def __init__(self, filename):
8+
SpikeGadgetsRawIO.__init__(self, filename=filename)
9+
BaseFromRaw.__init__(self, filename)

neo/rawio/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
* :attr:`RawBinarySignalRawIO`
3535
* :attr:`RawMCSRawIO`
3636
* :attr:`Spike2RawIO`
37+
* :attr:`SpikeGadgetsRawIO`
3738
* :attr:`SpikeGLXRawIO`
3839
* :attr:`TdtRawIO`
3940
* :attr:`WinEdrRawIO`
@@ -128,6 +129,10 @@
128129
129130
.. autoattribute:: extensions
130131
132+
.. autoclass:: neo.rawio.SpikeGadgetsRawIO
133+
134+
.. autoattribute:: extensions
135+
131136
.. autoclass:: neo.rawio.SpikeGLXRawIO
132137
133138
.. autoattribute:: extensions
@@ -170,6 +175,7 @@
170175
from neo.rawio.rawbinarysignalrawio import RawBinarySignalRawIO
171176
from neo.rawio.rawmcsrawio import RawMCSRawIO
172177
from neo.rawio.spike2rawio import Spike2RawIO
178+
from neo.rawio.spikegadgetsrawio import SpikeGadgetsRawIO
173179
from neo.rawio.spikeglxrawio import SpikeGLXRawIO
174180
from neo.rawio.tdtrawio import TdtRawIO
175181
from neo.rawio.winedrrawio import WinEdrRawIO
@@ -198,6 +204,7 @@
198204
RawBinarySignalRawIO,
199205
RawMCSRawIO,
200206
Spike2RawIO,
207+
SpikeGadgetsRawIO,
201208
SpikeGLXRawIO,
202209
TdtRawIO,
203210
WinEdrRawIO,

neo/rawio/spikegadgetsrawio.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""
2+
Class for reading spikegadgets files.
3+
Only continuous signals are supported at the moment.
4+
5+
https://spikegadgets.com/spike-products/
6+
7+
Documentation of the format:
8+
https://bitbucket.org/mkarlsso/trodes/wiki/Configuration
9+
10+
Note :
11+
* this file format have multiple version. news version include the gain for scaling.
12+
The actual implementation do not contain this feature because we don't have
13+
files to test this. So now the gain is "hardcoded" to 1. and so units
14+
is not handled correctly.
15+
16+
The ".rec" file format contains:
17+
* a first text part with information in an XML structure
18+
* a second part for the binary buffer
19+
20+
Author: Samuel Garcia
21+
"""
22+
from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
23+
_spike_channel_dtype, _event_channel_dtype)
24+
25+
import numpy as np
26+
27+
from xml.etree import ElementTree
28+
29+
30+
class SpikeGadgetsRawIO(BaseRawIO):
31+
extensions = ['rec']
32+
rawmode = 'one-file'
33+
34+
def __init__(self, filename='', selected_streams=None):
35+
"""
36+
Class for reading spikegadgets files.
37+
Only continuous signals are supported at the moment.
38+
39+
Initialize a SpikeGadgetsRawIO for a single ".rec" file.
40+
41+
Args:
42+
filename: str
43+
The filename
44+
selected_streams: None, list, str
45+
sublist of streams to load/expose to API
46+
useful for spikeextractor when one stream only is needed.
47+
For instance streams = ['ECU', 'trodes']
48+
'trodes' is name for ephy channel (ntrodes)
49+
"""
50+
BaseRawIO.__init__(self)
51+
self.filename = filename
52+
self.selected_streams = selected_streams
53+
54+
def _source_name(self):
55+
return self.filename
56+
57+
def _parse_header(self):
58+
# parse file until "</Configuration>"
59+
header_size = None
60+
with open(self.filename, mode='rb') as f:
61+
while True:
62+
line = f.readline()
63+
if b"</Configuration>" in line:
64+
header_size = f.tell()
65+
break
66+
67+
if header_size is None:
68+
ValueError("SpikeGadgets: the xml header does not contain '</Configuration>'")
69+
70+
f.seek(0)
71+
header_txt = f.read(header_size).decode('utf8')
72+
73+
# explore xml header
74+
root = ElementTree.fromstring(header_txt)
75+
gconf = sr = root.find('GlobalConfiguration')
76+
hconf = root.find('HardwareConfiguration')
77+
sconf = root.find('SpikeConfiguration')
78+
79+
self._sampling_rate = float(hconf.attrib['samplingRate'])
80+
num_ephy_channels = int(hconf .attrib['numChannels'])
81+
82+
# explore sub stream and count packet size
83+
# first bytes is 0x55
84+
packet_size = 1
85+
stream_bytes = {}
86+
for device in hconf:
87+
stream_id = device.attrib['name']
88+
num_bytes = int(device.attrib['numBytes'])
89+
stream_bytes[stream_id] = packet_size
90+
packet_size += num_bytes
91+
92+
# timesteamps 4 uint32
93+
self._timestamp_byte = packet_size
94+
packet_size += 4
95+
96+
packet_size += 2 * num_ephy_channels
97+
98+
# read the binary part lazily
99+
raw_memmap = np.memmap(self.filename, mode='r', offset=header_size, dtype='<u1')
100+
101+
num_packet = raw_memmap.size // packet_size
102+
raw_memmap = raw_memmap[:num_packet * packet_size]
103+
self._raw_memmap = raw_memmap.reshape(-1, packet_size)
104+
105+
# create signal channels
106+
stream_ids = []
107+
signal_streams = []
108+
signal_channels = []
109+
110+
# walk in xml device and keep only "analog" one
111+
self._mask_channels_bytes = {}
112+
for device in hconf:
113+
stream_id = device.attrib['name']
114+
for channel in device:
115+
116+
if 'interleavedDataIDByte' in channel.attrib:
117+
# TODO LATER: deal with "headstageSensor" which have interleaved
118+
continue
119+
120+
if channel.attrib['dataType'] == 'analog':
121+
122+
if stream_id not in stream_ids:
123+
stream_ids.append(stream_id)
124+
stream_name = stream_id
125+
signal_streams.append((stream_name, stream_id))
126+
self._mask_channels_bytes[stream_id] = []
127+
128+
name = channel.attrib['id']
129+
chan_id = channel.attrib['id']
130+
dtype = 'int16'
131+
# TODO LATER : handle gain correctly according the file version
132+
units = ''
133+
gain = 1.
134+
offset = 0.
135+
signal_channels.append((name, chan_id, self._sampling_rate, 'int16',
136+
units, gain, offset, stream_id))
137+
138+
num_bytes = stream_bytes[stream_id] + int(channel.attrib['startByte'])
139+
chan_mask = np.zeros(packet_size, dtype='bool')
140+
chan_mask[num_bytes] = True
141+
chan_mask[num_bytes + 1] = True
142+
self._mask_channels_bytes[stream_id].append(chan_mask)
143+
144+
if num_ephy_channels > 0:
145+
stream_id = 'trodes'
146+
stream_name = stream_id
147+
signal_streams.append((stream_name, stream_id))
148+
self._mask_channels_bytes[stream_id] = []
149+
150+
chan_ind = 0
151+
for trode in sconf:
152+
for schan in trode:
153+
name = 'trode' + trode.attrib['id'] + 'chan' + schan.attrib['hwChan']
154+
chan_id = schan.attrib['hwChan']
155+
# TODO LATER : handle gain correctly according the file version
156+
units = ''
157+
gain = 1.
158+
offset = 0.
159+
signal_channels.append((name, chan_id, self._sampling_rate, 'int16',
160+
units, gain, offset, stream_id))
161+
162+
chan_mask = np.zeros(packet_size, dtype='bool')
163+
num_bytes = packet_size - 2 * num_ephy_channels + 2 * chan_ind
164+
chan_mask[num_bytes] = True
165+
chan_mask[num_bytes + 1] = True
166+
self._mask_channels_bytes[stream_id].append(chan_mask)
167+
168+
chan_ind += 1
169+
170+
# make mask as array (used in _get_analogsignal_chunk(...))
171+
self._mask_streams = {}
172+
for stream_id, l in self._mask_channels_bytes.items():
173+
mask = np.array(l)
174+
self._mask_channels_bytes[stream_id] = mask
175+
self._mask_streams[stream_id] = np.any(mask, axis=0)
176+
177+
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
178+
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
179+
180+
# remove some stream if no wanted
181+
if self.selected_streams is not None:
182+
if isinstance(self.selected_streams, str):
183+
self.selected_streams = [self.selected_streams]
184+
assert isinstance(self.selected_streams, list)
185+
186+
keep = np.in1d(signal_streams['id'], self.selected_streams)
187+
signal_streams = signal_streams[keep]
188+
189+
keep = np.in1d(signal_channels['stream_id'], self.selected_streams)
190+
signal_channels = signal_channels[keep]
191+
192+
# No events channels
193+
event_channels = []
194+
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
195+
196+
# No spikes channels
197+
spike_channels = []
198+
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
199+
200+
# fille into header dict
201+
self.header = {}
202+
self.header['nb_block'] = 1
203+
self.header['nb_segment'] = [1]
204+
self.header['signal_streams'] = signal_streams
205+
self.header['signal_channels'] = signal_channels
206+
self.header['spike_channels'] = spike_channels
207+
self.header['event_channels'] = event_channels
208+
209+
self._generate_minimal_annotations()
210+
# info from GlobalConfiguration in xml are copied to block and seg annotations
211+
bl_ann = self.raw_annotations['blocks'][0]
212+
seg_ann = self.raw_annotations['blocks'][0]['segments'][0]
213+
for ann in (bl_ann, seg_ann):
214+
ann.update(gconf.attrib)
215+
216+
def _segment_t_start(self, block_index, seg_index):
217+
return 0.
218+
219+
def _segment_t_stop(self, block_index, seg_index):
220+
size = self._raw_memmap.shape[0]
221+
t_stop = size / self._sampling_rate
222+
return t_stop
223+
224+
def _get_signal_size(self, block_index, seg_index, stream_index):
225+
size = self._raw_memmap.shape[0]
226+
return size
227+
228+
def _get_signal_t_start(self, block_index, seg_index, stream_index):
229+
return 0.
230+
231+
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index,
232+
channel_indexes):
233+
stream_id = self.header['signal_streams'][stream_index]['id']
234+
235+
raw_unit8 = self._raw_memmap[i_start:i_stop]
236+
237+
num_chan = len(self._mask_channels_bytes[stream_id])
238+
re_order = None
239+
if channel_indexes is None:
240+
# no loop : entire stream mask
241+
stream_mask = self._mask_streams[stream_id]
242+
else:
243+
# accumulate mask
244+
if isinstance(channel_indexes, slice):
245+
chan_inds = np.arange(num_chan)[channel_indexes]
246+
else:
247+
chan_inds = channel_indexes
248+
249+
if np.any(np.diff(channel_indexes) < 0):
250+
# handle channel are not ordered
251+
sorted_channel_indexes = np.sort(channel_indexes)
252+
re_order = np.array([list(sorted_channel_indexes).index(ch)
253+
for ch in channel_indexes])
254+
255+
stream_mask = np.zeros(raw_unit8.shape[1], dtype='bool')
256+
for chan_ind in chan_inds:
257+
chan_mask = self._mask_channels_bytes[stream_id][chan_ind]
258+
stream_mask |= chan_mask
259+
260+
# this copies the data from the memmap into memory
261+
raw_unit8_mask = raw_unit8[:, stream_mask]
262+
shape = raw_unit8_mask.shape
263+
shape = (shape[0], shape[1] // 2)
264+
# reshape the and retype by view
265+
raw_unit16 = raw_unit8_mask.flatten().view('int16').reshape(shape)
266+
267+
if re_order is not None:
268+
raw_unit16 = raw_unit16[:, re_order]
269+
270+
return raw_unit16
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import unittest
2+
3+
from neo.io import SpikeGadgetsIO
4+
from neo.test.iotest.common_io_test import BaseTestIO
5+
6+
7+
class TestSpikeGadgetsIO(BaseTestIO, unittest.TestCase, ):
8+
ioclass = SpikeGadgetsIO
9+
entities_to_download = ['spikegadgets']
10+
entities_to_test = [
11+
'spikegadgets/20210225_em8_minirec2_ac.rec',
12+
'spikegadgets/W122_06_09_2019_1_fromSD.rec'
13+
]
14+
15+
16+
if __name__ == "__main__":
17+
unittest.main()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import unittest
2+
3+
from neo.rawio.spikegadgetsrawio import SpikeGadgetsRawIO
4+
from neo.test.rawiotest.common_rawio_test import BaseTestRawIO
5+
6+
7+
class TestSpikeGadgetsRawIO(BaseTestRawIO, unittest.TestCase, ):
8+
rawioclass = SpikeGadgetsRawIO
9+
entities_to_download = ['spikegadgets']
10+
entities_to_test = [
11+
'spikegadgets/20210225_em8_minirec2_ac.rec',
12+
'spikegadgets/W122_06_09_2019_1_fromSD.rec'
13+
]
14+
15+
16+
if __name__ == "__main__":
17+
unittest.main()

0 commit comments

Comments
 (0)