Skip to content

Commit 43b86a7

Browse files
committed
add mearec option to load only recordings or only sorting
1 parent 2d63e18 commit 43b86a7

File tree

1 file changed

+53
-31
lines changed

1 file changed

+53
-31
lines changed

neo/rawio/mearecrawio.py

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -36,53 +36,63 @@ class MEArecRawIO(BaseRawIO):
3636
extensions = ['h5']
3737
rawmode = 'one-file'
3838

39-
def __init__(self, filename=''):
39+
def __init__(self, filename='', load_spiketrains=True, load_recordings=True):
4040
BaseRawIO.__init__(self)
4141
self.filename = filename
42-
42+
self.load_spiketrains = load_spiketrains
43+
self.load_recordings = load_recordings
44+
4345
def _source_name(self):
4446
return self.filename
4547

4648
def _parse_header(self):
4749
import MEArec as mr
50+
load = ['channel_positions']
51+
if self.load_recordings:
52+
load.append("recordings")
53+
if self.load_spiketrains:
54+
load.append("spiketrains")
55+
4856
self._recgen = mr.load_recordings(recordings=self.filename, return_h5_objects=True,
4957
check_suffix=False,
50-
load=['recordings', 'spiketrains', 'channel_positions'],
58+
load=load,
5159
load_waveforms=False)
5260
self._sampling_rate = self._recgen.info['recordings']['fs']
53-
self._recordings = self._recgen.recordings
54-
self._num_frames, self._num_channels = self._recordings.shape
5561

5662
signal_streams = np.array([('Signals', '0')], dtype=_signal_stream_dtype)
5763

5864
sig_channels = []
59-
for c in range(self._num_channels):
60-
ch_name = 'ch{}'.format(c)
61-
chan_id = str(c + 1)
62-
sr = self._sampling_rate # Hz
63-
dtype = self._recordings.dtype
64-
units = 'uV'
65-
gain = 1.
66-
offset = 0.
67-
stream_id = '0'
68-
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id))
65+
if self.load_recordings:
66+
self._recordings = self._recgen.recordings
67+
self._num_frames, self._num_channels = self._recordings.shape
68+
for c in range(self._num_channels):
69+
ch_name = 'ch{}'.format(c)
70+
chan_id = str(c + 1)
71+
sr = self._sampling_rate # Hz
72+
dtype = self._recordings.dtype
73+
units = 'uV'
74+
gain = 1.
75+
offset = 0.
76+
stream_id = '0'
77+
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id))
6978
sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype)
7079

7180
# creating units channels
7281
spike_channels = []
73-
self._spiketrains = self._recgen.spiketrains
74-
for c in range(len(self._spiketrains)):
75-
unit_name = 'unit{}'.format(c)
76-
unit_id = '#{}'.format(c)
77-
# if spiketrains[c].waveforms is not None:
78-
wf_units = ''
79-
wf_gain = 1.
80-
wf_offset = 0.
81-
wf_left_sweep = 0
82-
wf_sampling_rate = self._sampling_rate
83-
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
84-
wf_offset, wf_left_sweep, wf_sampling_rate))
85-
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
82+
if self.load_spiketrains:
83+
self._spiketrains = self._recgen.spiketrains
84+
for c in range(len(self._spiketrains)):
85+
unit_name = 'unit{}'.format(c)
86+
unit_id = '#{}'.format(c)
87+
# if spiketrains[c].waveforms is not None:
88+
wf_units = ''
89+
wf_gain = 1.
90+
wf_offset = 0.
91+
wf_left_sweep = 0
92+
wf_sampling_rate = self._sampling_rate
93+
spike_channels.append((unit_name, unit_id, wf_units, wf_gain,
94+
wf_offset, wf_left_sweep, wf_sampling_rate))
95+
spike_channels = np.array(spike_channels, dtype=_spike_channel_dtype)
8696

8797
event_channels = []
8898
event_channels = np.array(event_channels, dtype=_event_channel_dtype)
@@ -119,6 +129,10 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):
119129

120130
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
121131
stream_index, channel_indexes):
132+
133+
if not self.load_recordings:
134+
raise AttributeError("Recordings not loaded. Set load_recordings=True in MEArecRawIO constructor")
135+
122136
if i_start is None:
123137
i_start = 0
124138
if i_stop is None:
@@ -127,23 +141,31 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
127141
if channel_indexes is None:
128142
channel_indexes = slice(self._num_channels)
129143
if isinstance(channel_indexes, slice):
130-
raw_signals = self._recgen.recordings[i_start:i_stop, channel_indexes]
144+
raw_signals = self._recordings[i_start:i_stop, channel_indexes]
131145
else:
132146
# sort channels because h5py neeeds sorted indexes
133147
if np.any(np.diff(channel_indexes) < 0):
134148
sorted_channel_indexes = np.sort(channel_indexes)
135149
sorted_idx = np.array([list(sorted_channel_indexes).index(ch)
136150
for ch in channel_indexes])
137-
raw_signals = self._recgen.recordings[i_start:i_stop, sorted_channel_indexes]
151+
raw_signals = self._recordings[i_start:i_stop, sorted_channel_indexes]
138152
raw_signals = raw_signals[:, sorted_idx]
139153
else:
140-
raw_signals = self._recgen.recordings[i_start:i_stop, channel_indexes]
154+
raw_signals = self._recordings[i_start:i_stop, channel_indexes]
141155
return raw_signals
142156

143157
def _spike_count(self, block_index, seg_index, unit_index):
158+
159+
if not self.load_spiketrains:
160+
raise AttributeError("Spiketrains not loaded. Set load_spiketrains=True in MEArecRawIO constructor")
161+
144162
return len(self._spiketrains[unit_index])
145163

146164
def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop):
165+
166+
if not self.load_spiketrains:
167+
raise AttributeError("Spiketrains not loaded. Set load_spiketrains=True in MEArecRawIO constructor")
168+
147169
spike_timestamps = self._spiketrains[unit_index].times.magnitude
148170
if t_start is None:
149171
t_start = self._segment_t_start(block_index, seg_index)

0 commit comments

Comments
 (0)