@@ -23,9 +23,21 @@ class SpikeGadgetsRawIO(BaseRawIO):
2323 extensions = ['rec' ]
2424 rawmode = 'one-file'
2525
26- def __init__ (self , filename = '' ):
26+ def __init__ (self , filename = '' , selected_streams = None ):
27+ """
28+
29+ filename: str
30+ filename ".rec"
31+
32+ streams: None, list, str
33+ sublist of streams to load/expose to API
34+ uselfull for spikeextractor when one stream isneed.
35+ For instance streams = ['ECU', 'trodes']
36+ 'trodes' is name for ephy channel (ntrodes)
37+ """
2738 BaseRawIO .__init__ (self )
2839 self .filename = filename
40+ self .selected_streams = selected_streams
2941
3042 def _source_name (self ):
3143 return self .filename
@@ -71,9 +83,10 @@ def _parse_header(self):
7183 packet_size += 4
7284
7385 packet_size += 2 * num_ephy_channels
74-
86+
7587 # read the binary part lazily
7688 raw_memmap = np .memmap (self .filename , mode = 'r' , offset = header_size , dtype = '<u1' )
89+
7790 num_packet = raw_memmap .size // packet_size
7891 raw_memmap = raw_memmap [:num_packet * packet_size ]
7992 self ._raw_memmap = raw_memmap .reshape (- 1 , packet_size )
@@ -151,6 +164,19 @@ def _parse_header(self):
151164 signal_streams = np .array (signal_streams , dtype = _signal_stream_dtype )
152165 signal_channels = np .array (signal_channels , dtype = _signal_channel_dtype )
153166
167+
168+ # remove some stream if no wanted
169+ if self .selected_streams is not None :
170+ if isinstance (self .selected_streams , str ):
171+ self .selected_streams = [self .selected_streams ]
172+ assert isinstance (self .selected_streams , list )
173+
174+ keep = np .in1d (signal_streams ['id' ], self .selected_streams )
175+ signal_streams = signal_streams [keep ]
176+
177+ keep = np .in1d (signal_channels ['stream_id' ], self .selected_streams )
178+ signal_channels = signal_channels [keep ]
179+
154180 # No events channels
155181 event_channels = []
156182 event_channels = np .array (event_channels , dtype = _event_channel_dtype )
@@ -201,7 +227,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, strea
201227 stream_mask = self ._mask_streams [stream_id ]
202228 else :
203229 # acculate mask
204- if instance (channel_indexes , slice ):
230+ if isinstance (channel_indexes , slice ):
205231 chan_inds = np .arange (num_chan )[channel_indexes ]
206232 else :
207233 chan_inds = channel_indexes
0 commit comments