Skip to content

Commit faabc78

Browse files
committed
expose explicit stream selection.
1 parent 61b46d3 commit faabc78

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

neo/rawio/spikegadgetsrawio.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,21 @@ class SpikeGadgetsRawIO(BaseRawIO):
2323
extensions = ['rec']
2424
rawmode = 'one-file'
2525

26-
def __init__(self, filename=''):
26+
def __init__(self, filename='', selected_streams=None):
27+
"""
28+
29+
filename: str
30+
filename ".rec"
31+
32+
streams: None, list, str
33+
sublist of streams to load/expose to API
34+
uselfull for spikeextractor when one stream isneed.
35+
For instance streams = ['ECU', 'trodes']
36+
'trodes' is name for ephy channel (ntrodes)
37+
"""
2738
BaseRawIO.__init__(self)
2839
self.filename = filename
40+
self.selected_streams = selected_streams
2941

3042
def _source_name(self):
3143
return self.filename
@@ -71,9 +83,10 @@ def _parse_header(self):
7183
packet_size += 4
7284

7385
packet_size += 2 * num_ephy_channels
74-
86+
7587
# read the binary part lazily
7688
raw_memmap = np.memmap(self.filename, mode='r', offset=header_size, dtype='<u1')
89+
7790
num_packet = raw_memmap.size // packet_size
7891
raw_memmap = raw_memmap[:num_packet*packet_size]
7992
self._raw_memmap = raw_memmap.reshape(-1, packet_size)
@@ -151,6 +164,19 @@ def _parse_header(self):
151164
signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)
152165
signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
153166

167+
168+
# remove some stream if no wanted
169+
if self.selected_streams is not None:
170+
if isinstance(self.selected_streams, str):
171+
self.selected_streams = [self.selected_streams]
172+
assert isinstance(self.selected_streams, list)
173+
174+
keep = np.in1d(signal_streams['id'], self.selected_streams)
175+
signal_streams = signal_streams[keep]
176+
177+
keep = np.in1d(signal_channels['stream_id'], self.selected_streams)
178+
signal_channels = signal_channels[keep]
179+
154180
# No events channels
155181
event_channels = []
156182
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
@@ -201,7 +227,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
201227
stream_mask = self._mask_streams[stream_id]
202228
else:
203229
# acculate mask
204-
if instance(channel_indexes, slice):
230+
if isinstance(channel_indexes, slice):
205231
chan_inds = np.arange(num_chan)[channel_indexes]
206232
else:
207233
chan_inds = channel_indexes

0 commit comments

Comments
 (0)