
Commit a827909

rawio improvement : nixio_fr
1 parent b459992

1 file changed: +60 -52 lines changed

1 file changed

+60
-52
lines changed

neo/rawio/nixrawio.py

Lines changed: 60 additions & 52 deletions
@@ -7,8 +7,9 @@
 Author: Chek Yin Choi
 """
 
-from .baserawio import (BaseRawIO, _signal_channel_dtype,
-                        _spike_channel_dtype, _event_channel_dtype)
+from .baserawio import (BaseRawIO, _signal_channel_dtype, _signal_stream_dtype,
+                        _spike_channel_dtype, _event_channel_dtype)
+
 from ..io.nixio import NixIO
 from ..io.nixio import check_nix_version
 import numpy as np
@@ -48,8 +49,9 @@ def _source_name(self):
 
     def _parse_header(self):
         self.file = nix.File.open(self.filename, nix.FileMode.ReadOnly)
-        sig_channels = []
+        signal_channels = []
         size_list = []
+        stream_ids = []
         for bl in self.file.blocks:
             for seg in bl.groups:
                 for da_idx, da in enumerate(seg.data_arrays):
@@ -62,20 +64,22 @@ def _parse_header(self):
                         da_leng = da.size
                         if da_leng not in size_list:
                             size_list.append(da_leng)
-                            group_id = 0
-                        for sid, li_leng in enumerate(size_list):
-                            if li_leng == da_leng:
-                                group_id = sid
-                                # very important! group_id use to store
-                                # channel groups!!!
-                                # use only for different signal length
+                            stream_ids.append(str(len(size_list)))
+                        # very important! group_id use to store
+                        # channel groups!!!
+                        # use only for different signal length
+                        stream_index = size_list.index(da_leng)
+                        stream_id = stream_ids[stream_index]
                         gain = 1
                         offset = 0.
-                        sig_channels.append((ch_name, chan_id, sr, dtype,
-                                             units, gain, offset, group_id))
+                        signal_channels.append((ch_name, chan_id, sr, dtype,
+                                                units, gain, offset, stream_id))
                 break
             break
-        sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
+        signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
+        signal_streams = np.zeros(len(stream_ids), dtype=_signal_stream_dtype)
+        signal_streams['id'] = stream_ids
+        signal_streams['name'] = ''
 
         spike_channels = []
         unit_name = ""
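
Note (editorial sketch, not part of the committed file): the new bookkeeping opens a stream for every DataArray length seen for the first time, and every channel sharing that length is tagged with the same stream_id string. A minimal standalone illustration of the same logic, using hypothetical lengths:

# Illustrative only: mirrors the size_list / stream_ids bookkeeping in the hunk above.
da_lengths = [1000, 1000, 2500, 1000, 2500]     # hypothetical DataArray sizes
size_list, stream_ids, per_channel_stream_id = [], [], []
for da_leng in da_lengths:
    if da_leng not in size_list:
        size_list.append(da_leng)
        stream_ids.append(str(len(size_list)))  # ids come out as "1", "2", ...
    per_channel_stream_id.append(stream_ids[size_list.index(da_leng)])

print(per_channel_stream_id)                    # ['1', '1', '2', '1', '2'] -> two streams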
@@ -183,7 +187,8 @@ def _parse_header(self):
         self.header = {}
         self.header['nb_block'] = len(self.file.blocks)
         self.header['nb_segment'] = [len(bl.groups) for bl in self.file.blocks]
-        self.header['signal_channels'] = sig_channels
+        self.header['signal_streams'] = signal_streams
+        self.header['signal_channels'] = signal_channels
         self.header['spike_channels'] = spike_channels
         self.header['event_channels'] = event_channels
 
@@ -192,27 +197,28 @@ def _parse_header(self):
             bl_ann = self.raw_annotations['blocks'][blk_idx]
             props = blk.metadata.inherited_properties()
             bl_ann.update(self._filter_properties(props, "block"))
-            for grp_idx, grp in enumerate(blk.groups):
+            for grp_idx, group in enumerate(blk.groups):
                 seg_ann = bl_ann['segments'][grp_idx]
-                props = grp.metadata.inherited_properties()
+                props = group.metadata.inherited_properties()
                 seg_ann.update(self._filter_properties(props, "segment"))
-                sig_idx = 0
-                groupdas = NixIO._group_signals(grp.data_arrays)
-                for nix_name, signals in groupdas.items():
-                    da = signals[0]
-                    if da.type == 'neo.analogsignal' and seg_ann['signals']:
-                        # collect and group DataArrays
-                        sig_ann = seg_ann['signals'][sig_idx]
-                        sig_chan_ann = self.raw_annotations['signal_channels'][sig_idx]
-                        props = da.metadata.inherited_properties()
-                        sig_ann.update(self._filter_properties(props, 'analogsignal'))
-                        sig_chan_ann.update(self._filter_properties(props, 'analogsignal'))
-                        sig_idx += 1
+                # TODO handle annotation at stream level
+                # sig_idx = 0
+                # groupdas = NixIO._group_signals(grp.data_arrays)
+                # for nix_name, signals in groupdas.items():
+                #     da = signals[0]
+                #     if da.type == 'neo.analogsignal' and seg_ann['signals']:
+                #         # collect and group DataArrays
+                #         sig_ann = seg_ann['signals'][sig_idx]
+                #         sig_chan_ann = self.raw_annotations['signal_channels'][sig_idx]
+                #         props = da.metadata.inherited_properties()
+                #         sig_ann.update(self._filter_properties(props, 'analogsignal'))
+                #         sig_chan_ann.update(self._filter_properties(props, 'analogsignal'))
+                #         sig_idx += 1
                 sp_idx = 0
                 ev_idx = 0
-                for mt in grp.multi_tags:
-                    if mt.type == 'neo.spiketrain' and seg_ann['units']:
-                        st_ann = seg_ann['units'][sp_idx]
+                for mt in group.multi_tags:
+                    if mt.type == 'neo.spiketrain' and seg_ann['spikes']:
+                        st_ann = seg_ann['spikes'][sp_idx]
                         props = mt.metadata.inherited_properties()
                         st_ann.update(self._filter_properties(props, 'spiketrain'))
                         sp_idx += 1
@@ -225,11 +231,11 @@ def _parse_header(self):
                         event_ann.update(self._filter_properties(props, 'event'))
                         ev_idx += 1
 
-            # populate ChannelIndex annotations
-            for srcidx, source in enumerate(blk.sources):
-                chx_ann = self.raw_annotations["signal_channels"][srcidx]
-                props = source.metadata.inherited_properties()
-                chx_ann.update(self._filter_properties(props, "channelindex"))
+            #~ # populate ChannelIndex annotations
+            #~ for srcidx, source in enumerate(blk.sources):
+                #~ chx_ann = self.raw_annotations["signal_channels"][srcidx]
+                #~ props = source.metadata.inherited_properties()
+                #~ chx_ann.update(self._filter_properties(props, "channelindex"))
 
     def _segment_t_start(self, block_index, seg_index):
         t_start = 0
@@ -245,41 +251,43 @@ def _segment_t_stop(self, block_index, seg_index):
                 t_stop = mt.metadata['t_stop']
         return t_stop
 
-    def _get_signal_size(self, block_index, seg_index, channel_indexes):
-        if channel_indexes is None:
-            channel_indexes = list(range(self.header['signal_channels'].size))
+    def _get_signal_size(self, block_index, seg_index, stream_index):
+        stream_id = self.header['signal_streams'][stream_index]['id']
+        keep = self.header['signal_channels']['stream_id'] == stream_id
+        channel_indexes, = np.nonzero(keep)
         ch_idx = channel_indexes[0]
         block = self.da_list['blocks'][block_index]
         segment = block['segments'][seg_index]
         size = segment['data_size'][ch_idx]
         return size # size is per signal, not the sum of all channel_indexes
 
-    def _get_signal_t_start(self, block_index, seg_index, channel_indexes):
-        if channel_indexes is None:
-            channel_indexes = list(range(self.header['signal_channels'].size))
+    def _get_signal_t_start(self, block_index, seg_index, stream_index):
+        stream_id = self.header['signal_streams'][stream_index]['id']
+        keep = self.header['signal_channels']['stream_id'] == stream_id
+        channel_indexes, = np.nonzero(keep)
         ch_idx = channel_indexes[0]
         block = self.file.blocks[block_index]
         das = [da for da in block.groups[seg_index].data_arrays]
         da = das[ch_idx]
         sig_t_start = float(da.metadata['t_start'])
         return sig_t_start # assume same group_id always same t_start
 
-    def _get_analogsignal_chunk(self, block_index, seg_index,
-                                i_start, i_stop, channel_indexes):
-        if channel_indexes is None:
-            channel_indexes = list(range(self.header['signal_channels'].size))
+    def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
+                                stream_index, channel_indexes):
+        stream_id = self.header['signal_streams'][stream_index]['id']
+        keep = self.header['signal_channels']['stream_id'] == stream_id
+        global_channel_indexes, = np.nonzero(keep)
+        if channel_indexes is not None:
+            global_channel_indexes = global_channel_indexes[channel_indexes]
+
         if i_start is None:
             i_start = 0
         if i_stop is None:
-            block = self.da_list['blocks'][block_index]
-            segment = block['segments'][seg_index]
-            for c in channel_indexes:
-                i_stop = segment['data_size'][c]
-                break
+            i_stop = self.get_signal_size(block_index, seg_index, stream_index)
 
         raw_signals_list = []
         da_list = self.da_list['blocks'][block_index]['segments'][seg_index]
-        for idx in channel_indexes:
+        for idx in global_channel_indexes:
             da = da_list['data'][idx]
             raw_signals_list.append(da[i_start:i_stop])
0 commit comments
