Skip to content

Commit 33b05b6

Browse files
authored
Merge pull request #1258 from h-mayorquin/improve_mearec
Add option to `MEArecRawIO` for loading only recordings or only sorting data
2 parents a351f4e + 1863d55 commit 33b05b6

File tree

4 files changed

+128
-37
lines changed

4 files changed

+128
-37
lines changed

neo/io/mearecio.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ class MEArecIO(MEArecRawIO, BaseFromRaw):
66
__doc__ = MEArecRawIO.__doc__
77
mode = 'file'
88

9-
def __init__(self, filename):
10-
MEArecRawIO.__init__(self, filename=filename)
9+
def __init__(self, filename, load_spiketrains=True, load_analogsignal=True):
10+
MEArecRawIO.__init__(self,
11+
filename=filename,
12+
load_spiketrains=load_spiketrains,
13+
load_analogsignal=load_analogsignal
14+
)
1115
BaseFromRaw.__init__(self, filename)

neo/rawio/baserawio.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,18 @@ def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_sto
550550
np.ndarray and are contiguous
551551
:return: array with raw signal samples
552552
"""
553+
554+
signal_streams = self.header['signal_streams']
555+
signal_channels = self.header['signal_channels']
556+
no_signal_streams = signal_streams.size == 0
557+
no_channels = signal_channels.size == 0
558+
if no_signal_streams or no_channels:
559+
error_message = (
560+
"get_analogsignal_chunk can't be called on a file with no signal streams or channels."
561+
"Double check that your file contains signal streams and channels."
562+
)
563+
raise AttributeError(error_message)
564+
553565
stream_index = self._get_stream_index_from_arg(stream_index)
554566
channel_indexes = self._get_channel_indexes(stream_index, channel_indexes,
555567
channel_names, channel_ids)
@@ -579,7 +591,7 @@ def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index=
579591
channel_indexes=None, channel_names=None, channel_ids=None):
580592
"""
581593
Rescale a chunk of raw signals which are provided as a Numpy array. These are normally
582-
returned by a call to get_analog_signal_chunk. The channels are specified either by
594+
returned by a call to get_analogsignal_chunk. The channels are specified either by
583595
channel_names, if provided, otherwise by channel_ids, if provided, otherwise by
584596
channel_indexes, if provided, otherwise all channels are selected.
585597

neo/rawio/mearecrawio.py

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,18 @@ class MEArecRawIO(BaseRawIO):
2020
"""
2121
Class for "reading" fake data from a MEArec file.
2222
23+
This class provides a convenient way to read data from a MEArec file.
24+
25+
Parameters
26+
----------
27+
filename : str
28+
The filename of the MEArec file to read.
29+
load_spiketrains : bool, optional
30+
Whether or not to load spike train data. Defaults to `True`.
31+
load_analogsignal : bool, optional
32+
Whether or not to load continuous recording data. Defaults to `True`.
33+
34+
2335
Usage:
2436
>>> import neo.rawio
2537
>>> r = neo.rawio.MEArecRawIO(filename='mearec.h5')
@@ -36,52 +48,75 @@ class MEArecRawIO(BaseRawIO):
3648
extensions = ['h5']
3749
rawmode = 'one-file'
3850

39-
def __init__(self, filename=''):
51+
def __init__(self, filename='', load_spiketrains=True, load_analogsignal=True):
4052
BaseRawIO.__init__(self)
4153
self.filename = filename
42-
54+
self.load_spiketrains = load_spiketrains
55+
self.load_analogsignal = load_analogsignal
56+
4357
def _source_name(self):
4458
return self.filename
4559

4660
def _parse_header(self):
61+
load = ["channel_positions"]
62+
if self.load_analogsignal:
63+
load.append("recordings")
64+
if self.load_spiketrains:
65+
load.append("spiketrains")
66+
4767
import MEArec as mr
4868
self._recgen = mr.load_recordings(recordings=self.filename, return_h5_objects=True,
4969
check_suffix=False,
50-
load=['recordings', 'spiketrains', 'channel_positions'],
70+
load=load,
5171
load_waveforms=False)
52-
self._sampling_rate = self._recgen.info['recordings']['fs']
53-
self._recordings = self._recgen.recordings
54-
self._num_frames, self._num_channels = self._recordings.shape
55-
56-
signal_streams = np.array([('Signals', '0')], dtype=_signal_stream_dtype)
5772

73+
self.info_dict = deepcopy(self._recgen.info)
74+
self.channel_positions = self._recgen.channel_positions
75+
if self.load_analogsignal:
76+
self._recordings = self._recgen.recordings
77+
if self.load_spiketrains:
78+
self._spiketrains = self._recgen.spiketrains
79+
80+
self._sampling_rate = self.info_dict['recordings']['fs']
81+
self.duration_seconds = self.info_dict["recordings"]["duration"]
82+
self._num_frames = int(self._sampling_rate * self.duration_seconds)
83+
self._num_channels = self.channel_positions.shape[0]
84+
self._dtype = self.info_dict["recordings"]["dtype"]
85+
86+
signals = [('Signals', '0')] if self.load_analogsignal else []
87+
signal_streams = np.array(signals, dtype=_signal_stream_dtype)
88+
89+
5890
sig_channels = []
59-
for c in range(self._num_channels):
60-
ch_name = 'ch{}'.format(c)
61-
chan_id = str(c + 1)
62-
sr = self._sampling_rate # Hz
63-
dtype = self._recordings.dtype
64-
units = 'uV'
65-
gain = 1.
66-
offset = 0.
67-
stream_id = '0'
68-
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id))
91+
if self.load_analogsignal:
92+
for c in range(self._num_channels):
93+
ch_name = 'ch{}'.format(c)
94+
chan_id = str(c + 1)
95+
sr = self._sampling_rate # Hz
96+
dtype = self._dtype
97+
units = 'uV'
98+
gain = 1.
99+
offset = 0.
100+
stream_id = '0'
101+
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id))
102+
69103
sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
70104

71105
# creating units channels
72106
spike_channels = []
73-
self._spiketrains = self._recgen.spiketrains
74-
for c in range(len(self._spiketrains)):
75-
unit_name = 'unit{}'.format(c)
76-
unit_id = '#{}'.format(c)
77-
# if spiketrains[c].waveforms is not None:
78-
wf_units = ''
79-
wf_gain = 1.
80-
wf_offset = 0.
81-
wf_left_sweep = 0
82-
wf_sampling_rate = self._sampling_rate
83-
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
84-
wf_offset, wf_left_sweep, wf_sampling_rate))
107+
if self.load_spiketrains:
108+
for c in range(len(self._spiketrains)):
109+
unit_name = 'unit{}'.format(c)
110+
unit_id = '#{}'.format(c)
111+
# if spiketrains[c].waveforms is not None:
112+
wf_units = ''
113+
wf_gain = 1.
114+
wf_offset = 0.
115+
wf_left_sweep = 0
116+
wf_sampling_rate = self._sampling_rate
117+
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
118+
wf_offset, wf_left_sweep, wf_sampling_rate))
119+
85120
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
86121

87122
event_channels = []
@@ -98,7 +133,7 @@ def _parse_header(self):
98133
self._generate_minimal_annotations()
99134
for block_index in range(1):
100135
bl_ann = self.raw_annotations['blocks'][block_index]
101-
bl_ann['mearec_info'] = deepcopy(self._recgen.info)
136+
bl_ann['mearec_info'] = self.info_dict
102137

103138
def _segment_t_start(self, block_index, seg_index):
104139
all_starts = [[0.]]
@@ -119,6 +154,10 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):
119154

120155
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
121156
stream_index, channel_indexes):
157+
158+
if not self.load_analogsignal:
159+
raise AttributeError("Recordings not loaded. Set load_analogsignal=True in MEArecRawIO constructor")
160+
122161
if i_start is None:
123162
i_start = 0
124163
if i_stop is None:
@@ -127,23 +166,25 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
127166
if channel_indexes is None:
128167
channel_indexes = slice(self._num_channels)
129168
if isinstance(channel_indexes, slice):
130-
raw_signals = self._recgen.recordings[i_start:i_stop, channel_indexes]
169+
raw_signals = self._recordings[i_start:i_stop, channel_indexes]
131170
else:
132171
# sort channels because h5py neeeds sorted indexes
133172
if np.any(np.diff(channel_indexes) < 0):
134173
sorted_channel_indexes = np.sort(channel_indexes)
135174
sorted_idx = np.array([list(sorted_channel_indexes).index(ch)
136175
for ch in channel_indexes])
137-
raw_signals = self._recgen.recordings[i_start:i_stop, sorted_channel_indexes]
176+
raw_signals = self._recordings[i_start:i_stop, sorted_channel_indexes]
138177
raw_signals = raw_signals[:, sorted_idx]
139178
else:
140-
raw_signals = self._recgen.recordings[i_start:i_stop, channel_indexes]
179+
raw_signals = self._recordings[i_start:i_stop, channel_indexes]
141180
return raw_signals
142181

143182
def _spike_count(self, block_index, seg_index, unit_index):
183+
144184
return len(self._spiketrains[unit_index])
145185

146186
def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
187+
147188
spike_timestamps = self._spiketrains[unit_index].times.magnitude
148189
if t_start is None:
149190
t_start = self._segment_t_start(block_index, seg_index)

neo/test/rawiotest/test_mearecrawio.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,40 @@ class TestMEArecRawIO(BaseTestRawIO, unittest.TestCase, ):
2727
'mearec/mearec_test_10s.h5'
2828
]
2929

30+
def test_not_loading_recordings(self):
31+
32+
filename = self.entities_to_test[0]
33+
filename = self.get_local_path(filename)
34+
rawio = self.rawioclass(filename=filename, load_analogsignal=False)
35+
rawio.parse_header()
36+
37+
# Test that rawio does not have a _recordings attribute
38+
self.assertFalse(hasattr(rawio, '_recordings'))
39+
40+
# Test that calling get_spike_timestamps works
41+
rawio.get_spike_timestamps()
42+
43+
# Test that caling anlogsignal chunk raises the right error
44+
with self.assertRaises(AttributeError):
45+
rawio.get_analogsignal_chunk()
46+
47+
48+
def test_not_loading_spiketrain(self):
49+
50+
filename = self.entities_to_test[0]
51+
filename = self.get_local_path(filename)
52+
rawio = self.rawioclass(filename=filename, load_spiketrains=False)
53+
rawio.parse_header()
54+
55+
# Test that rawio does not have a _spiketrains attribute
56+
self.assertFalse(hasattr(rawio, '_spiketrains'))
57+
58+
# Test that calling analogsignal chunk works
59+
rawio.get_analogsignal_chunk()
60+
61+
# Test that calling get_spike_timestamps raises an the right error
62+
with self.assertRaises(AttributeError):
63+
rawio.get_spike_timestamps()
3064

3165
if __name__ == "__main__":
3266
unittest.main()

0 commit comments

Comments
 (0)