@@ -36,53 +36,63 @@ class MEArecRawIO(BaseRawIO):
3636 extensions = ['h5' ]
3737 rawmode = 'one-file'
3838
39- def __init__ (self , filename = '' ):
39+ def __init__ (self , filename = '' , load_spiketrains = True , load_recordings = True ):
4040 BaseRawIO .__init__ (self )
4141 self .filename = filename
42-
42+ self .load_spiketrains = load_spiketrains
43+ self .load_recordings = load_recordings
44+
4345 def _source_name (self ):
4446 return self .filename
4547
4648 def _parse_header (self ):
4749 import MEArec as mr
50+ load = ['channel_positions' ]
51+ if self .load_recordings :
52+ load .append ("recordings" )
53+ if self .load_spiketrains :
54+ load .append ("spiketrains" )
55+
4856 self ._recgen = mr .load_recordings (recordings = self .filename , return_h5_objects = True ,
4957 check_suffix = False ,
50- load = [ 'recordings' , 'spiketrains' , 'channel_positions' ] ,
58+ load = load ,
5159 load_waveforms = False )
5260 self ._sampling_rate = self ._recgen .info ['recordings' ]['fs' ]
53- self ._recordings = self ._recgen .recordings
54- self ._num_frames , self ._num_channels = self ._recordings .shape
5561
5662 signal_streams = np .array ([('Signals' , '0' )], dtype = _signal_stream_dtype )
5763
5864 sig_channels = []
59- for c in range (self ._num_channels ):
60- ch_name = 'ch{}' .format (c )
61- chan_id = str (c + 1 )
62- sr = self ._sampling_rate # Hz
63- dtype = self ._recordings .dtype
64- units = 'uV'
65- gain = 1.
66- offset = 0.
67- stream_id = '0'
68- sig_channels .append ((ch_name , chan_id , sr , dtype , units , gain , offset , stream_id ))
65+ if self .load_recordings :
66+ self ._recordings = self ._recgen .recordings
67+ self ._num_frames , self ._num_channels = self ._recordings .shape
68+ for c in range (self ._num_channels ):
69+ ch_name = 'ch{}' .format (c )
70+ chan_id = str (c + 1 )
71+ sr = self ._sampling_rate # Hz
72+ dtype = self ._recordings .dtype
73+ units = 'uV'
74+ gain = 1.
75+ offset = 0.
76+ stream_id = '0'
77+ sig_channels .append ((ch_name , chan_id , sr , dtype , units , gain , offset , stream_id ))
6978 sig_channels = np .array (sig_channels , dtype = _signal_channel_dtype )
7079
7180 # creating units channels
7281 spike_channels = []
73- self ._spiketrains = self ._recgen .spiketrains
74- for c in range (len (self ._spiketrains )):
75- unit_name = 'unit{}' .format (c )
76- unit_id = '#{}' .format (c )
77- # if spiketrains[c].waveforms is not None:
78- wf_units = ''
79- wf_gain = 1.
80- wf_offset = 0.
81- wf_left_sweep = 0
82- wf_sampling_rate = self ._sampling_rate
83- spike_channels .append ((unit_name , unit_id , wf_units , wf_gain ,
84- wf_offset , wf_left_sweep , wf_sampling_rate ))
85- spike_channels = np .array (spike_channels , dtype = _spike_channel_dtype )
82+ if self .load_spiketrains :
83+ self ._spiketrains = self ._recgen .spiketrains
84+ for c in range (len (self ._spiketrains )):
85+ unit_name = 'unit{}' .format (c )
86+ unit_id = '#{}' .format (c )
87+ # if spiketrains[c].waveforms is not None:
88+ wf_units = ''
89+ wf_gain = 1.
90+ wf_offset = 0.
91+ wf_left_sweep = 0
92+ wf_sampling_rate = self ._sampling_rate
93+ spike_channels .append ((unit_name , unit_id , wf_units , wf_gain ,
94+ wf_offset , wf_left_sweep , wf_sampling_rate ))
95+ spike_channels = np .array (spike_channels , dtype = _spike_channel_dtype )
8696
8797 event_channels = []
8898 event_channels = np .array (event_channels , dtype = _event_channel_dtype )
@@ -119,6 +129,10 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):
119129
120130 def _get_analogsignal_chunk (self , block_index , seg_index , i_start , i_stop ,
121131 stream_index , channel_indexes ):
132+
133+ if not self .load_recordings :
134+ raise AttributeError ("Recordings not loaded. Set load_recordings=True in MEArecRawIO constructor" )
135+
122136 if i_start is None :
123137 i_start = 0
124138 if i_stop is None :
@@ -127,23 +141,31 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
127141 if channel_indexes is None :
128142 channel_indexes = slice (self ._num_channels )
129143 if isinstance (channel_indexes , slice ):
130- raw_signals = self ._recgen . recordings [i_start :i_stop , channel_indexes ]
144+ raw_signals = self ._recordings [i_start :i_stop , channel_indexes ]
131145 else :
132146 # sort channels because h5py neeeds sorted indexes
133147 if np .any (np .diff (channel_indexes ) < 0 ):
134148 sorted_channel_indexes = np .sort (channel_indexes )
135149 sorted_idx = np .array ([list (sorted_channel_indexes ).index (ch )
136150 for ch in channel_indexes ])
137- raw_signals = self ._recgen . recordings [i_start :i_stop , sorted_channel_indexes ]
151+ raw_signals = self ._recordings [i_start :i_stop , sorted_channel_indexes ]
138152 raw_signals = raw_signals [:, sorted_idx ]
139153 else :
140- raw_signals = self ._recgen . recordings [i_start :i_stop , channel_indexes ]
154+ raw_signals = self ._recordings [i_start :i_stop , channel_indexes ]
141155 return raw_signals
142156
143157 def _spike_count (self , block_index , seg_index , unit_index ):
158+
159+ if not self .load_spiketrains :
160+ raise AttributeError ("Spiketrains not loaded. Set load_spiketrains=True in MEArecRawIO constructor" )
161+
144162 return len (self ._spiketrains [unit_index ])
145163
146164 def _get_spike_timestamps (self , block_index , seg_index , unit_index , t_start , t_stop ):
165+
166+ if not self .load_spiketrains :
167+ raise AttributeError ("Spiketrains not loaded. Set load_spiketrains=True in MEArecRawIO constructor" )
168+
147169 spike_timestamps = self ._spiketrains [unit_index ].times .magnitude
148170 if t_start is None :
149171 t_start = self ._segment_t_start (block_index , seg_index )
0 commit comments