@@ -20,6 +20,18 @@ class MEArecRawIO(BaseRawIO):
2020 """
2121 Class for "reading" fake data from a MEArec file.
2222
23+ This class provides a convenient way to read data from a MEArec file.
24+
25+ Parameters
26+ ----------
27+ filename : str
28+ The filename of the MEArec file to read.
29+ load_spiketrains : bool, optional
30+ Whether or not to load spike train data. Defaults to `True`.
31+ load_analogsignal : bool, optional
32+ Whether or not to load continuous recording data. Defaults to `True`.
33+
34+
2335 Usage:
2436 >>> import neo.rawio
2537 >>> r = neo.rawio.MEArecRawIO(filename='mearec.h5')
@@ -36,52 +48,75 @@ class MEArecRawIO(BaseRawIO):
3648 extensions = ['h5' ]
3749 rawmode = 'one-file'
3850
39- def __init__ (self , filename = '' ):
51+ def __init__ (self , filename = '' , load_spiketrains = True , load_analogsignal = True ):
4052 BaseRawIO .__init__ (self )
4153 self .filename = filename
42-
54+ self .load_spiketrains = load_spiketrains
55+ self .load_analogsignal = load_analogsignal
56+
4357 def _source_name (self ):
4458 return self .filename
4559
4660 def _parse_header (self ):
61+ load = ["channel_positions" ]
62+ if self .load_analogsignal :
63+ load .append ("recordings" )
64+ if self .load_spiketrains :
65+ load .append ("spiketrains" )
66+
4767 import MEArec as mr
4868 self ._recgen = mr .load_recordings (recordings = self .filename , return_h5_objects = True ,
4969 check_suffix = False ,
50- load = [ 'recordings' , 'spiketrains' , 'channel_positions' ] ,
70+ load = load ,
5171 load_waveforms = False )
52- self ._sampling_rate = self ._recgen .info ['recordings' ]['fs' ]
53- self ._recordings = self ._recgen .recordings
54- self ._num_frames , self ._num_channels = self ._recordings .shape
55-
56- signal_streams = np .array ([('Signals' , '0' )], dtype = _signal_stream_dtype )
5772
73+ self .info_dict = deepcopy (self ._recgen .info )
74+ self .channel_positions = self ._recgen .channel_positions
75+ if self .load_analogsignal :
76+ self ._recordings = self ._recgen .recordings
77+ if self .load_spiketrains :
78+ self ._spiketrains = self ._recgen .spiketrains
79+
80+ self ._sampling_rate = self .info_dict ['recordings' ]['fs' ]
81+ self .duration_seconds = self .info_dict ["recordings" ]["duration" ]
82+ self ._num_frames = int (self ._sampling_rate * self .duration_seconds )
83+ self ._num_channels = self .channel_positions .shape [0 ]
84+ self ._dtype = self .info_dict ["recordings" ]["dtype" ]
85+
86+ signals = [('Signals' , '0' )] if self .load_analogsignal else []
87+ signal_streams = np .array (signals , dtype = _signal_stream_dtype )
88+
89+
5890 sig_channels = []
59- for c in range (self ._num_channels ):
60- ch_name = 'ch{}' .format (c )
61- chan_id = str (c + 1 )
62- sr = self ._sampling_rate # Hz
63- dtype = self ._recordings .dtype
64- units = 'uV'
65- gain = 1.
66- offset = 0.
67- stream_id = '0'
68- sig_channels .append ((ch_name , chan_id , sr , dtype , units , gain , offset , stream_id ))
91+ if self .load_analogsignal :
92+ for c in range (self ._num_channels ):
93+ ch_name = 'ch{}' .format (c )
94+ chan_id = str (c + 1 )
95+ sr = self ._sampling_rate # Hz
96+ dtype = self ._dtype
97+ units = 'uV'
98+ gain = 1.
99+ offset = 0.
100+ stream_id = '0'
101+ sig_channels .append ((ch_name , chan_id , sr , dtype , units , gain , offset , stream_id ))
102+
69103 sig_channels = np .array (sig_channels , dtype = _signal_channel_dtype )
70104
71105 # creating units channels
72106 spike_channels = []
73- self ._spiketrains = self ._recgen .spiketrains
74- for c in range (len (self ._spiketrains )):
75- unit_name = 'unit{}' .format (c )
76- unit_id = '#{}' .format (c )
77- # if spiketrains[c].waveforms is not None:
78- wf_units = ''
79- wf_gain = 1.
80- wf_offset = 0.
81- wf_left_sweep = 0
82- wf_sampling_rate = self ._sampling_rate
83- spike_channels .append ((unit_name , unit_id , wf_units , wf_gain ,
84- wf_offset , wf_left_sweep , wf_sampling_rate ))
107+ if self .load_spiketrains :
108+ for c in range (len (self ._spiketrains )):
109+ unit_name = 'unit{}' .format (c )
110+ unit_id = '#{}' .format (c )
111+ # if spiketrains[c].waveforms is not None:
112+ wf_units = ''
113+ wf_gain = 1.
114+ wf_offset = 0.
115+ wf_left_sweep = 0
116+ wf_sampling_rate = self ._sampling_rate
117+ spike_channels .append ((unit_name , unit_id , wf_units , wf_gain ,
118+ wf_offset , wf_left_sweep , wf_sampling_rate ))
119+
85120 spike_channels = np .array (spike_channels , dtype = _spike_channel_dtype )
86121
87122 event_channels = []
@@ -98,7 +133,7 @@ def _parse_header(self):
98133 self ._generate_minimal_annotations ()
99134 for block_index in range (1 ):
100135 bl_ann = self .raw_annotations ['blocks' ][block_index ]
101- bl_ann ['mearec_info' ] = deepcopy ( self ._recgen . info )
136+ bl_ann ['mearec_info' ] = self .info_dict
102137
103138 def _segment_t_start (self , block_index , seg_index ):
104139 all_starts = [[0. ]]
@@ -119,6 +154,10 @@ def _get_signal_t_start(self, block_index, seg_index, stream_index):
119154
120155 def _get_analogsignal_chunk (self , block_index , seg_index , i_start , i_stop ,
121156 stream_index , channel_indexes ):
157+
158+ if not self .load_analogsignal :
159+ raise AttributeError ("Recordings not loaded. Set load_analogsignal=True in MEArecRawIO constructor" )
160+
122161 if i_start is None :
123162 i_start = 0
124163 if i_stop is None :
@@ -127,23 +166,25 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
127166 if channel_indexes is None :
128167 channel_indexes = slice (self ._num_channels )
129168 if isinstance (channel_indexes , slice ):
130- raw_signals = self ._recgen . recordings [i_start :i_stop , channel_indexes ]
169+ raw_signals = self ._recordings [i_start :i_stop , channel_indexes ]
131170 else :
132171 # sort channels because h5py neeeds sorted indexes
133172 if np .any (np .diff (channel_indexes ) < 0 ):
134173 sorted_channel_indexes = np .sort (channel_indexes )
135174 sorted_idx = np .array ([list (sorted_channel_indexes ).index (ch )
136175 for ch in channel_indexes ])
137- raw_signals = self ._recgen . recordings [i_start :i_stop , sorted_channel_indexes ]
176+ raw_signals = self ._recordings [i_start :i_stop , sorted_channel_indexes ]
138177 raw_signals = raw_signals [:, sorted_idx ]
139178 else :
140- raw_signals = self ._recgen . recordings [i_start :i_stop , channel_indexes ]
179+ raw_signals = self ._recordings [i_start :i_stop , channel_indexes ]
141180 return raw_signals
142181
143182 def _spike_count (self , block_index , seg_index , unit_index ):
183+
144184 return len (self ._spiketrains [unit_index ])
145185
146186 def _get_spike_timestamps (self , block_index , seg_index , unit_index , t_start , t_stop ):
187+
147188 spike_timestamps = self ._spiketrains [unit_index ].times .magnitude
148189 if t_start is None :
149190 t_start = self ._segment_t_start (block_index , seg_index )
0 commit comments