diff --git a/neo/core/spiketrain.py b/neo/core/spiketrain.py index 7256511bf..3d23cb0fc 100644 --- a/neo/core/spiketrain.py +++ b/neo/core/spiketrain.py @@ -76,7 +76,7 @@ def _check_time_in_range(value, t_start, t_stop, view=False): def _check_waveform_dimensions(spiketrain): ''' Verify that waveform is compliant with the waveform definition as - quantity array 3D (spike, channel_index, time) + quantity array 3D (time, spike, channel_index) ''' if not spiketrain.size: @@ -87,10 +87,10 @@ def _check_waveform_dimensions(spiketrain): if (waveforms is None) or (not waveforms.size): return - if waveforms.shape[0] != len(spiketrain): + if waveforms.shape[1] != len(spiketrain): raise ValueError("Spiketrain length (%s) does not match to number of " "waveforms present (%s)" % (len(spiketrain), - waveforms.shape[0])) + waveforms.shape[1])) def _new_spiketrain(cls, signal, t_stop, units=None, dtype=None, @@ -161,7 +161,7 @@ class SpikeTrain(BaseNeo, pq.Quantity): :class:`SpikeTrain` began. This will be converted to the same units as :attr:`times`. Default: 0.0 seconds. - :waveforms: (quantity array 3D (spike, channel_index, time)) + :waveforms: (quantity array 3D (time, spike, channel_index)) The waveforms of each spike. :sampling_rate: (quantity scalar) Number of samples per unit time for the waveforms. @@ -184,7 +184,7 @@ class SpikeTrain(BaseNeo, pq.Quantity): read-only. (:attr:`t_stop` - :attr:`t_start`) :spike_duration: (quantity scalar) Duration of a waveform, read-only. - (:attr:`waveform`.shape[2] * :attr:`sampling_period`) + (:attr:`waveform`.shape[0] * :attr:`sampling_period`) :right_sweep: (quantity scalar) Time from the trigger times of the spikes to the end of the waveforms, read-only. (:attr:`left_sweep` + :attr:`spike_duration`) @@ -219,9 +219,7 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True, This is called whenever a new :class:`SpikeTrain` is created from the constructor, but not when slicing. 
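The docstring changes above switch the waveform layout from (spike, channel_index, time) to (time, spike, channel_index). A minimal sketch of the new convention (the sizes and the 30 kHz rate are illustrative, not taken from the patch):

    import numpy as np
    import quantities as pq

    times = [1.0, 2.0, 3.0] * pq.s
    waveforms = np.zeros((32, 3, 1)) * pq.mV   # (time, spike, channel_index)

    assert waveforms.shape[1] == len(times)    # spikes now live on axis 1
    spike_duration = waveforms.shape[0] / (30000.0 * pq.Hz)  # time on axis 0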
''' - if len(times) != 0 and waveforms is not None and len(times) != \ - waveforms.shape[0]: - # len(times)!=0 has been used to workaround a bug occuring during neo import + if len(times) != 0 and waveforms is not None and len(times) != waveforms.shape[1]: raise ValueError( "the number of waveforms should be equal to the number of spikes") @@ -435,7 +433,7 @@ def sort(self): # sort the waveforms by the times sort_indices = np.argsort(self) if self.waveforms is not None and self.waveforms.any(): - self.waveforms = self.waveforms[sort_indices] + self.waveforms = self.waveforms[:, sort_indices, :] # now sort the times # We have sorted twice, but `self = self[sort_indices]` introduces @@ -492,7 +490,7 @@ def __getitem__(self, i): ''' obj = super(SpikeTrain, self).__getitem__(i) if hasattr(obj, 'waveforms') and obj.waveforms is not None: - obj.waveforms = obj.waveforms.__getitem__(i) + obj.waveforms = obj.waveforms.__getitem__((slice(None), i, slice(None))) return obj def __setitem__(self, i, value): @@ -570,7 +568,7 @@ def time_slice(self, t_start, t_stop): new_st.t_start = max(_t_start, self.t_start) new_st.t_stop = min(_t_stop, self.t_stop) if self.waveforms is not None: - new_st.waveforms = self.waveforms[indices] + new_st.waveforms = self.waveforms[:, indices, :] return new_st @@ -627,8 +625,8 @@ def merge(self, other): sampling_rate=self.sampling_rate, left_sweep=self.left_sweep, **kwargs) if all(wfs): - wfs_stack = np.vstack((self.waveforms, other.waveforms)) - wfs_stack = wfs_stack[sorting] + wfs_stack = np.concatenate((self.waveforms, other.waveforms), axis=1) + wfs_stack = wfs_stack[:, sorting, :] train.waveforms = wfs_stack train.segment = self.segment if train.segment is not None: @@ -661,11 +659,11 @@ def spike_duration(self): ''' Duration of a waveform. - (:attr:`waveform`.shape[2] * :attr:`sampling_period`) + (:attr:`waveform`.shape[0] * :attr:`sampling_period`) ''' if self.waveforms is None or self.sampling_rate is None: return None - return self.waveforms.shape[2] / self.sampling_rate + return self.waveforms.shape[0] / self.sampling_rate @property def sampling_period(self): diff --git a/neo/io/brainwaresrcio.py b/neo/io/brainwaresrcio.py index 08d21db44..cda561a48 100755 --- a/neo/io/brainwaresrcio.py +++ b/neo/io/brainwaresrcio.py @@ -619,6 +619,7 @@ def _combine_spiketrains(self, spiketrains): # get the maximum time t_stop = times[-1] * 2.
+ waveforms = np.moveaxis(waveforms, 2, 0) waveforms = pq.Quantity(waveforms, units=pq.mV, copy=False) train = SpikeTrain(times=times, copy=False, diff --git a/neo/rawio/baserawio.py b/neo/rawio/baserawio.py index b4bf8b479..6d1e125f2 100644 --- a/neo/rawio/baserawio.py +++ b/neo/rawio/baserawio.py @@ -40,7 +40,7 @@ """ -#from __future__ import unicode_literals, print_function, division, absolute_import +# from __future__ import unicode_literals, print_function, division, absolute_import from __future__ import print_function, division, absolute_import import logging @@ -52,16 +52,15 @@ try: import joblib + HAVE_JOBLIB = True except ImportError: HAVE_JOBLIB = False - possible_raw_modes = ['one-file', 'multi-file', 'one-dir', ] # 'multi-dir', 'url', 'other' error_header = 'Header is not read yet, do parse_header() first' - _signal_channel_dtype = [ ('name', 'U64'), ('id', 'int64'), @@ -73,8 +72,7 @@ ('group_id', 'int64'), ] -_common_sig_characteristics = ['sampling_rate', 'dtype', 'group_id'] - +_common_sig_characteristics = ['sampling_rate', 'dtype', 'group_id'] _unit_channel_dtype = [ ('name', 'U64'), @@ -87,7 +85,6 @@ ('wf_sampling_rate', 'float64'), ] - _event_channel_dtype = [ ('name', 'U64'), ('id', 'U64'), @@ -107,7 +104,7 @@ class BaseRawIO(object): rawmode = None # one key in possible_raw_modes - def __init__(self, use_cache=False, cache_path='same_as_resource', **kargs): + def __init__(self, use_cache=False, cache_path='same_as_resource', **kargs): """ When rawmode=='one-file' kargs MUST contains 'filename' the filename @@ -191,9 +188,12 @@ def _generate_minimal_annotations(self): Usage: raw_annotations['blocks'][block_index] = { 'nickname' : 'super block', 'segments' : ...} raw_annotations['blocks'][block_index] = { 'nickname' : 'super block', 'segments' : ...} - raw_annotations['blocks'][block_index]['segments'][seg_index]['signals'][channel_index] = {'nickname': 'super channel'} - raw_annotations['blocks'][block_index]['segments'][seg_index]['units'][unit_index] = {'nickname': 'super neuron'} - raw_annotations['blocks'][block_index]['segments'][seg_index]['events'][ev_chan] = {'nickname': 'super trigger'} + raw_annotations['blocks'][block_index]['segments'][seg_index]['signals'][channel_index] = \ + {'nickname': 'super channel'} + raw_annotations['blocks'][block_index]['segments'][seg_index]['units'][unit_index] = \ + {'nickname': 'super neuron'} + raw_annotations['blocks'][block_index]['segments'][seg_index]['events'][ev_chan] = \ + {'nickname': 'super trigger'} Theses annotations will be used at the neo.io API directly in objects. @@ -261,7 +261,7 @@ def _generate_minimal_annotations(self): self.raw_annotations = a - def _raw_annotate(self, obj_name, chan_index=0, block_index=0, seg_index=0, **kargs): + def _raw_annotate(self, obj_name, chan_index=0, block_index=0, seg_index=0, **kargs): """ Annotate a object in the list/dict tree annotations. """ @@ -284,7 +284,7 @@ def _repr_annotations(self): bl_a = self.raw_annotations['blocks'][block_index] txt += '*Block {}\n'.format(block_index) for k, v in bl_a.items(): - if k in ('segments', ): + if k in ('segments',): continue txt += ' -{}: {}\n'.format(k, v) for seg_index in range(self.segment_count(block_index)): @@ -363,13 +363,13 @@ def _group_signal_channel_characteristics(self): If all channels have the same characteristics them `get_analogsignal_chunk` can be call wihtout restriction. 
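The np.moveaxis(waveforms, 2, 0) call added above (and repeated in the other IOs touched by this patch) converts the legacy in-memory layout to the new one without copying; a small sketch with made-up sizes:

    import numpy as np

    legacy = np.empty((100, 4, 32))         # legacy (spike, channel_index, time)
    converted = np.moveaxis(legacy, 2, 0)   # a view: axis 2 moved to the front
    assert converted.shape == (32, 100, 4)  # new (time, spike, channel_index)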
If not then **channel_indexes** must be specified - in `get_analogsignal_chunk` and only channels with same + in `get_analogsignal_chunk` and only channels with same caracteristics can be read at the same time. - This is usefull for some IO than + This is useful for some IOs that have internally several signals channels familly. - For many RawIO all channels have the same + For many RawIO all channels have the same sampling_rate/size/t_start. In that cases, internal flag **self._several_channel_groups will be set to False, so `get_analogsignal_chunk(..)` won't suffer in performance. @@ -393,17 +393,17 @@ def _check_common_characteristics(self, channel_indexes): """ Usefull for few IOs (TdtrawIO, NeuroExplorerRawIO, ...). - Check is a set a signal channel_indexes share common + Check if a set of signal channel_indexes shares common characteristics (**sampling_rate/t_start/size**) Usefull only when RawIO propose differents channels groups with differents sampling_rate for instance. """ - #~ print('_check_common_characteristics', channel_indexes) + # ~ print('_check_common_characteristics', channel_indexes) - assert channel_indexes is not None,\ + assert channel_indexes is not None, \ 'You must specify channel_indexes' characteristics = self.header['signal_channels'][_common_sig_characteristics] - #~ print(characteristics[channel_indexes]) + # ~ print(characteristics[channel_indexes]) assert np.unique(characteristics[channel_indexes]).size == 1, \ 'This channel set have differents characteristics' @@ -418,7 +418,7 @@ def get_group_channel_indexes(self): unique_characteristics = np.unique(characteristics) channel_indexes_list = [] for e in unique_characteristics: - channel_indexes, = np.nonzero(characteristics == e) + channel_indexes, = np.nonzero(characteristics == e) channel_indexes_list.append(channel_indexes) return channel_indexes_list else: @@ -430,7 +430,7 @@ def channel_name_to_index(self, channel_names): Based on self.header['signal_channels'] """ ch = self.header['signal_channels'] - channel_indexes, = np.nonzero(np.in1d(ch['name'], channel_names)) + channel_indexes, = np.nonzero(np.in1d(ch['name'], channel_names)) assert len(channel_indexes) == len(channel_names), 'not match' return channel_indexes @@ -440,7 +440,7 @@ def channel_id_to_index(self, channel_ids): Based on self.header['signal_channels'] """ ch = self.header['signal_channels'] - channel_indexes, = np.nonzero(np.in1d(ch['id'], channel_ids)) + channel_indexes, = np.nonzero(np.in1d(ch['id'], channel_ids)) assert len(channel_indexes) == len(channel_ids), 'not match' return channel_indexes @@ -486,11 +486,11 @@ def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_sto self._check_common_characteristics(channel_indexes) raw_chunk = self._get_analogsignal_chunk( - block_index, seg_index, i_start, i_stop, channel_indexes) + block_index, seg_index, i_start, i_stop, channel_indexes) return raw_chunk - def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', + def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', channel_indexes=None, channel_names=None, channel_ids=None): channel_indexes = self._get_channel_indexes(channel_indexes, channel_names, channel_ids) @@ -510,10 +510,10 @@ def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', return float_signal # spiketrain and unit zone - def spike_count(self, block_index=0, seg_index=0, unit_index=0): + def spike_count(self, block_index=0, seg_index=0, unit_index=0): return self._spike_count(block_index, seg_index, unit_index) -
def get_spike_timestamps(self, block_index=0, seg_index=0, unit_index=0, + t_start=None, t_stop=None): """ The timestamp is as close to the format itself. Sometimes float/int32/int64. @@ -533,7 +533,7 @@ def rescale_spike_timestamp(self, spike_timestamps, dtype='float64'): return self._rescale_spike_timestamp(spike_timestamps, dtype) # spiketrain waveform zone - def get_spike_raw_waveforms(self, block_index=0, seg_index=0, unit_index=0, + def get_spike_raw_waveforms(self, block_index=0, seg_index=0, unit_index=0, t_start=None, t_stop=None): wf = self._get_spike_raw_waveforms(block_index, seg_index, unit_index, t_start, t_stop) return wf @@ -552,10 +552,10 @@ def rescale_waveforms_to_float(self, raw_waveforms, dtype='float32', unit_index= return float_waveforms # event and epoch zone - def event_count(self, block_index=0, seg_index=0, event_channel_index=0): + def event_count(self, block_index=0, seg_index=0, event_channel_index=0): return self._event_count(block_index, seg_index, event_channel_index) - def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0, + def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0, t_start=None, t_stop=None): """ The timestamp is as close to the format itself. Sometimes float/int32/int64. @@ -592,7 +592,7 @@ def setup_cache(self, cache_path, **init_kargs): elif self.rawmode == 'one-dir': ressource_name = self.dirname else: - raise(NotImlementedError) + raise (NotImplementedError) if cache_path == 'home': if sys.platform.startswith('win'): @@ -608,11 +608,12 @@ def setup_cache(self, cache_path, **init_kargs): elif cache_path == 'same_as_resource': dirname = os.path.dirname(ressource_name) else: - assert os.path.exists(cache_path),\ + assert os.path.exists(cache_path), \ 'cache_path do not exists use "home" or "same_as_file" to make this auto' # the hash of the ressource (dir of file) is done with filename+datetime - # TODO make something more sofisticated when rawmode='one-dir' that use all filename and datetime + # TODO make something more sophisticated when rawmode='one-dir', + # that uses all filename and datetime d = dict(ressource_name=ressource_name, mtime=os.path.getmtime(ressource_name)) hash = joblib.hash(d, hash_name='md5') @@ -642,56 +643,56 @@ def dump_cache(self): # Functions to be implement in IO below here def _parse_header(self): - raise(NotImplementedError) + raise (NotImplementedError) # must call # self._generate_empty_annotations() def _source_name(self): - raise(NotImplementedError) + raise (NotImplementedError) def _segment_t_start(self, block_index, seg_index): - raise(NotImplementedError) + raise (NotImplementedError) def _segment_t_stop(self, block_index, seg_index): - raise(NotImplementedError) + raise (NotImplementedError) ### # signal and channel zone def _get_signal_size(self, block_index, seg_index, channel_indexes): - raise(NotImplementedError) + raise (NotImplementedError) def _get_signal_t_start(self, block_index, seg_index, channel_indexes): - raise(NotImplementedError) + raise (NotImplementedError) - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): - raise(NotImplementedError) + def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): + raise (NotImplementedError) ### # spiketrain and unit zone - def _spike_count(self, block_index,
seg_index, unit_index): + raise (NotImplementedError) - def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): - raise(NotImplementedError) + def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): + raise (NotImplementedError) def _rescale_spike_timestamp(self, spike_timestamps, dtype): - raise(NotImplementedError) + raise (NotImplementedError) ### # spike waveforms zone def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop): - raise(NotImplementedError) + raise (NotImplementedError) ### # event and epoch zone def _event_count(self, block_index, seg_index, event_channel_index): - raise(NotImplementedError) + raise (NotImplementedError) - def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): - raise(NotImplementedError) + def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): + raise (NotImplementedError) def _rescale_event_timestamp(self, event_timestamps, dtype): - raise(NotImplementedError) + raise (NotImplementedError) def _rescale_epoch_duration(self, raw_duration, dtype): - raise(NotImplementedError) + raise (NotImplementedError) diff --git a/neo/rawio/blackrockrawio.py b/neo/rawio/blackrockrawio.py index 159ddd7be..926e0ff66 100644 --- a/neo/rawio/blackrockrawio.py +++ b/neo/rawio/blackrockrawio.py @@ -66,7 +66,6 @@ import numpy as np import quantities as pq - from .baserawio import (BaseRawIO, _signal_channel_dtype, _unit_channel_dtype, _event_channel_dtype) @@ -116,7 +115,7 @@ class BlackrockRawIO(BaseRawIO): >>> reader = BlackrockRawIO(filename='FileSpec2.3001', nsx_to_load=5) >>> reader.parse_header() - Inspect a set of file consisting of files FileSpec2.3001.ns5 and + Inspect a set of file consisting of files FileSpec2.3001.ns5 and FileSpec2.3001.nev >>> print(reader) @@ -279,7 +278,8 @@ def _parse_header(self): self.__nsx_basic_header[nsx_nb], self.__nsx_ext_header[nsx_nb] = \ self.__nsx_header_reader[spec](nsx_nb) - # Read nsx data header(s) for nsxdef get_analogsignal_shape(self, block_index, seg_index): + # Read nsx data header(s) for nsxdef get_analogsignal_shape + # (self, block_index, seg_index): self.__nsx_data_header[nsx_nb] = self.__nsx_dataheader_reader[spec](nsx_nb) # We can load only one for one class instance @@ -320,8 +320,8 @@ def _parse_header(self): sig_dtype = 'int16' # max_analog_val/min_analog_val/max_digital_val/min_analog_val are int16!!!!! # dangarous situation so cast to float everyone - gain = (float(chan['max_analog_val']) - float(chan['min_analog_val'])) /\ - (float(chan['max_digital_val']) - float(chan['min_digital_val'])) + gain = (float(chan['max_analog_val']) - float(chan['min_analog_val'])) / \ + (float(chan['max_digital_val']) - float(chan['min_digital_val'])) offset = -float(chan['min_digital_val']) * gain + float(chan['min_analog_val']) group_id = 0 sig_channels.append((ch_name, ch_id, sig_sampling_rate, sig_dtype, @@ -336,7 +336,7 @@ def _parse_header(self): t_start = 0. 
else: t_start = self.__nsx_data_header[self.nsx_to_load][data_bl]['timestamp'] / \ - sig_sampling_rate + sig_sampling_rate t_stop = t_start + length / sig_sampling_rate max_nev_time = 0 for k, data in self.nev_data.items(): @@ -531,7 +531,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, chann sig_chunk = memmap_data[i_start:i_stop, channel_indexes] return sig_chunk - def _spike_count(self, block_index, seg_index, unit_index): + def _spike_count(self, block_index, seg_index, unit_index): channel_id, unit_id = self.internal_unit_ids[unit_index] all_spikes = self.nev_data['Spikes'] @@ -542,12 +542,12 @@ def _spike_count(self, block_index, seg_index, unit_index): else: # must clip in time time range timestamp = all_spikes[mask]['timestamp'] - sl = self._get_timestamp_slice(timestamp, seg_index, None, None) + sl = self._get_timestamp_slice(timestamp, seg_index, None, None) timestamp = timestamp[sl] nb = timestamp.size return nb - def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): + def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): channel_id, unit_id = self.internal_unit_ids[unit_index] all_spikes = self.nev_data['Spikes'] @@ -557,7 +557,7 @@ def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_ unit_spikes = all_spikes[mask] timestamp = unit_spikes['timestamp'] - sl = self._get_timestamp_slice(timestamp, seg_index, t_start, t_stop) + sl = self._get_timestamp_slice(timestamp, seg_index, t_start, t_stop) timestamp = timestamp[sl] return timestamp @@ -604,8 +604,9 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, waveforms = waveforms.reshape(int(unit_spikes.size), 1, int(wf_size)) timestamp = unit_spikes['timestamp'] - sl = self._get_timestamp_slice(timestamp, seg_index, t_start, t_stop) + sl = self._get_timestamp_slice(timestamp, seg_index, t_start, t_stop) waveforms = waveforms[sl] + waveforms = np.moveaxis(waveforms, 2, 0) return waveforms @@ -619,7 +620,7 @@ def _event_count(self, block_index, seg_index, event_channel_index): else: # must clip in time time range timestamp = events_data[ev_dict['mask']]['timestamp'] - sl = self._get_timestamp_slice(timestamp, seg_index, None, None) + sl = self._get_timestamp_slice(timestamp, seg_index, None, None) timestamp = timestamp[sl] nb = timestamp.size return nb @@ -841,8 +842,8 @@ def __read_nsx_dataheader_variant_b( # data size = number of data points * (2bytes * number of channels) # use of `int` avoids overflow problem - data_size = int(dh['nb_data_points']) * \ - int(self.__nsx_basic_header[nsx_nb]['channel_count']) * 2 + channel_count = int(self.__nsx_basic_header[nsx_nb]['channel_count']) + data_size = int(dh['nb_data_points']) * channel_count * 2 # define new offset (to possible next data block) offset = data_header[index]['offset_to_data_block'] + data_size @@ -1261,7 +1262,7 @@ def __nev_data_types(self, data_size): ('video_frame_nb', 'uint32'), ('video_elapsed_time', 'uint32'), ('video_source_id', 'uint32'), - ('unused', 'int8', (data_size - 20, ))]}, + ('unused', 'int8', (data_size - 20,))]}, 'TrackingEvents': { 'a': [ ('timestamp', 'uint32'), @@ -1270,13 +1271,13 @@ def __nev_data_types(self, data_size): ('node_id', 'uint16'), ('node_count', 'uint16'), ('point_count', 'uint16'), - ('tracking_points', 'uint16', ((data_size - 14) // 2, ))]}, + ('tracking_points', 'uint16', ((data_size - 14) // 2,))]}, 'ButtonTrigger': { 'a': [ ('timestamp', 'uint32'), ('packet_id', 'uint16'), 
('trigger_type', 'uint16'), - ('unused', 'int8', (data_size - 8, ))]}, + ('unused', 'int8', (data_size - 8,))]}, 'ConfigEvent': { 'a': [ ('timestamp', 'uint32'), @@ -1503,8 +1504,8 @@ def __get_nsx_param_variant_a(self, nsx_nb): filename = '.'.join([self._filenames['nsx'], 'ns%i' % nsx_nb]) bytes_in_headers = self.__nsx_basic_header[nsx_nb].dtype.itemsize + \ - self.__nsx_ext_header[nsx_nb].dtype.itemsize * \ - self.__nsx_basic_header[nsx_nb]['channel_count'] + self.__nsx_ext_header[nsx_nb].dtype.itemsize * \ + self.__nsx_basic_header[nsx_nb]['channel_count'] nsx_parameters = { 'nb_data_points': int( @@ -1527,7 +1528,8 @@ def __get_nsx_param_variant_a(self, nsx_nb): 'time_unit': pq.CompoundUnit("1.0/{0}*s".format( 30000 / self.__nsx_basic_header[nsx_nb]['period']))} - return nsx_parameters # Returns complete dictionary because then it does not need to be called so often + # Returns complete dictionary because then it does not need to be called so often + return nsx_parameters def __get_nsx_param_variant_b(self, param_name, nsx_nb): """ diff --git a/neo/rawio/examplerawio.py b/neo/rawio/examplerawio.py index f662f27e2..dd7921b6b 100644 --- a/neo/rawio/examplerawio.py +++ b/neo/rawio/examplerawio.py @@ -13,18 +13,18 @@ * code hard! The main difficulty **is _parse_header()**. In short you have a create a mandatory dict than contains channel informations:: - + self.header = {} self.header['nb_block'] = 2 self.header['nb_segment'] = [2, 3] self.header['signal_channels'] = sig_channels self.header['unit_channels'] = unit_channels - self.header['event_channels'] = event_channels - + self.header['event_channels'] = event_channels + 2. Step 2: RawIO test: * create a file in neo/rawio/tests with the same name with "test_" prefix * copy paste neo/rawio/tests/test_examplerawio.py and do the same - + 3. 
Step 3 : Create the neo.io class with the wrapper * Create a file in neo/io/ that endith with "io.py" * Create a that hinerits bot yrou RawIO class and BaseFromRaw class @@ -75,7 +75,7 @@ class ExampleRawIO(BaseRawIO): >>> print(r) >>> raw_chunk = r.get_analogsignal_chunk(block_index=0, seg_index=0, i_start=0, i_stop=1024, channel_names=channel_names) - >>> float_chunk = reader.rescale_signal_raw_to_float(raw_chunk, dtype='float64', + >>> float_chunk = reader.rescale_signal_raw_to_float(raw_chunk, dtype='float64', channel_indexes=[0, 3, 6]) >>> spike_timestamp = reader.spike_timestamps(unit_index=0, t_start=None, t_stop=None) >>> spike_times = reader.rescale_spike_timestamp(spike_timestamp, 'float64') @@ -311,7 +311,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, np.random.seed(2205) # a magic number (my birthday) waveforms = np.random.randint(low=-2**4, high=2**4, size=20 * 50, dtype='int16') - waveforms = waveforms.reshape(20, 1, 50) + waveforms = waveforms.reshape(50, 20, 1) return waveforms def _event_count(self, block_index, seg_index, event_channel_index): @@ -328,7 +328,7 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_ # the main difference between spike channel and event channel # is that for here we have 3 numpy array timestamp, durations, labels # durations must be None for 'event' - # label must a dtype ='U' + # label must a dtype ='U' # in our IO event are directly coded in seconds seg_t_start = self._segment_t_start(block_index, seg_index) diff --git a/neo/rawio/neuralynxrawio.py b/neo/rawio/neuralynxrawio.py index 691787ab4..1240ef476 100644 --- a/neo/rawio/neuralynxrawio.py +++ b/neo/rawio/neuralynxrawio.py @@ -31,9 +31,8 @@ import datetime from collections import OrderedDict - BLOCK_SIZE = 512 # nb sample per signal block -HEADER_SIZE = 2**14 # file have a txt header of 16kB +HEADER_SIZE = 2 ** 14 # file have a txt header of 16kB class NeuralynxRawIO(BaseRawIO): @@ -104,10 +103,12 @@ def _parse_header(self): sig_channels.append((chan_name, chan_id, info['sampling_rate'], 'int16', units, gain, offset, group_id)) self.ncs_filenames[chan_id] = filename - keys = ['DspFilterDelay_µs', 'recording_opened', 'FileType', 'DspDelayCompensation', 'recording_closed', - 'DspLowCutFilterType', 'HardwareSubSystemName', 'DspLowCutNumTaps', 'DSPLowCutFilterEnabled', - 'HardwareSubSystemType', 'DspHighCutNumTaps', 'ADMaxValue', 'DspLowCutFrequency', - 'DSPHighCutFilterEnabled', 'RecordSize', 'InputRange', 'DspHighCutFrequency', + keys = ['DspFilterDelay_µs', 'recording_opened', 'FileType', + 'DspDelayCompensation', 'recording_closed', + 'DspLowCutFilterType', 'HardwareSubSystemName', 'DspLowCutNumTaps', + 'DSPLowCutFilterEnabled', 'HardwareSubSystemType', 'DspHighCutNumTaps', + 'ADMaxValue', 'DspLowCutFrequency', 'DSPHighCutFilterEnabled', + 'RecordSize', 'InputRange', 'DspHighCutFrequency', 'input_inverted', 'NumADChannels', 'DspHighCutFilterType', ] d = {k: info[k] for k in keys if k in info} signal_annotations.append(d) @@ -116,7 +117,8 @@ def _parse_header(self): # nse and ntt are pretty similar execept for the wavform shape # a file can contain several unit_id (so several unit channel) assert chan_id not in self.nse_ntt_filenames - self.nse_ntt_filenames[chan_id] = filename, 'Several nse or ntt files have the same unit_id!!!' + self.nse_ntt_filenames[ + chan_id] = filename, 'Several nse or ntt files have the same unit_id!!!' 
dtype = get_nse_or_ntt_dtype(info, ext) data = np.memmap(filename, dtype=dtype, mode='r', offset=HEADER_SIZE) @@ -136,8 +138,9 @@ def _parse_header(self): wf_offset = 0. wf_left_sweep = -1 # DONT KNOWN wf_sampling_rate = info['sampling_rate'] - unit_channels.append((unit_name, '{}'.format(unit_id), wf_units, wf_gain, wf_offset, - wf_left_sweep, wf_sampling_rate)) + unit_channels.append( + (unit_name, '{}'.format(unit_id), wf_units, wf_gain, wf_offset, + wf_left_sweep, wf_sampling_rate)) unit_annotations.append(dict(file_origin=filename)) elif ext == 'nev': @@ -233,10 +236,10 @@ def _parse_header(self): ev_ann = seg_annotations['events'][c] ev_ann['file_origin'] = self.nev_filenames[chan_id] - #~ ev_ann['marker_id'] = - #~ ev_ann['nttl'] = - #~ ev_ann['digital_marker'] = - #~ ev_ann['analog_marker'] = + # ~ ev_ann['marker_id'] = + # ~ ev_ann['nttl'] = + # ~ ev_ann['digital_marker'] = + # ~ ev_ann['analog_marker'] = def _segment_t_start(self, block_index, seg_index): return self._seg_t_starts[seg_index] - self.global_t_start @@ -250,7 +253,7 @@ def _get_signal_size(self, block_index, seg_index, channel_indexes): def _get_signal_t_start(self, block_index, seg_index, channel_indexes): return self._sigs_t_start[seg_index] - self.global_t_start - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): + def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): if i_start is None: i_start = 0 if i_stop is None: @@ -273,7 +276,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, chan return sigs_chunk - def _spike_count(self, block_index, seg_index, unit_index): + def _spike_count(self, block_index, seg_index, unit_index): chan_id, unit_id = self.internal_unit_ids[unit_index] data = self._spike_memmap[chan_id] ts = data['timestamp'] @@ -284,7 +287,7 @@ def _spike_count(self, block_index, seg_index, unit_index): nb_spike = int(data[keep].size) return nb_spike - def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): + def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): chan_id, unit_id = self.internal_unit_ids[unit_index] data = self._spike_memmap[chan_id] ts = data['timestamp'] @@ -325,6 +328,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, else: # case for ntt change (n, 32, 4) to (n, 4, 32) waveforms = wfs.swapaxes(1, 2) + waveforms = np.moveaxis(waveforms, 2, 0) return waveforms @@ -334,12 +338,12 @@ def _event_count(self, block_index, seg_index, event_channel_index): data = self._nev_memmap[chan_id] ts0, ts1 = self._timestamp_limits[seg_index] ts = data['timestamp'] - keep = (ts >= ts0) & (ts <= ts1) & (data['event_id'] == event_id) &\ - (data['ttl_input'] == ttl_input) + keep = (ts >= ts0) & (ts <= ts1) & (data['event_id'] == event_id) & \ + (data['ttl_input'] == ttl_input) nb_event = int(data[keep].size) return nb_event - def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): + def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): event_id, ttl_input = self.internal_event_ids[event_channel_index] chan_id = self.header['event_channels'][event_channel_index]['id'] data = self._nev_memmap[chan_id] @@ -351,8 +355,8 @@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_ ts1 = int((t_stop + self.global_t_start) * 1e6) ts = data['timestamp'] - keep = (ts >= ts0) & (ts <= ts1) & (data['event_id'] 
== event_id) &\ - (data['ttl_input'] == ttl_input) + keep = (ts >= ts0) & (ts <= ts1) & (data['event_id'] == event_id) & \ + (data['ttl_input'] == ttl_input) subdata = data[keep] timestamps = subdata['timestamp'] labels = subdata['event_string'].astype('U') @@ -406,7 +410,7 @@ def read_ncs_files(self, ncs_filenames): deltas0 = np.diff(timestamps0) # It should be that: - #gap_indexes, = np.nonzero(deltas0!=good_delta) + # gap_indexes, = np.nonzero(deltas0!=good_delta) # but for a file I have found many deltas0==15999 deltas0==16000 # I guess this is a round problem @@ -437,17 +441,19 @@ def read_ncs_files(self, ncs_filenames): i0 = gap_bounds[seg_index] i1 = gap_bounds[seg_index + 1] - assert data[i0]['timestamp'] == data0[i0]['timestamp'], 'ncs files do not have the same gaps' + assert data[i0]['timestamp'] == data0[i0][ + 'timestamp'], 'ncs files do not have the same gaps' assert data[i1 - 1]['timestamp'] == data0[i1 - - 1]['timestamp'], 'ncs files do not have the same gaps' + 1][ + 'timestamp'], 'ncs files do not have the same gaps' subdata = data[i0:i1] self._sigs_memmap[seg_index][chan_id] = subdata if chan_id == chan_id0: ts0 = subdata[0]['timestamp'] - ts1 = subdata[-1]['timestamp'] + \ - np.uint64(BLOCK_SIZE / self._sigs_sampling_rate * 1e6) + ts1 = (subdata[-1]['timestamp'] + + np.uint64(BLOCK_SIZE / self._sigs_sampling_rate * 1e6)) self._timestamp_limits.append((ts0, ts1)) t_start = ts0 / 1e6 self._sigs_t_start.append(t_start) @@ -506,7 +512,7 @@ def read_ncs_files(self, ncs_filenames): ('TimeClosed', '', None), ('ApplicationName Cheetah', 'version', None), # used possibilty 2 for version ('AcquisitionSystem', '', None), - ('ReferenceChannel', '', None), + ('ReferenceChannel', '', None), ] @@ -569,8 +575,8 @@ def read_txt_header(filename): if 'channel_id' not in info: info['channel_id'] = name - #~ for k, v in info.items(): - #~ print(' ', k, ':', v) + # ~ for k, v in info.items(): + # ~ print(' ', k, ':', v) return info diff --git a/neo/rawio/neuroexplorerrawio.py b/neo/rawio/neuroexplorerrawio.py index 049c4355e..0d6591b00 100644 --- a/neo/rawio/neuroexplorerrawio.py +++ b/neo/rawio/neuroexplorerrawio.py @@ -94,7 +94,7 @@ def _parse_header(self): gain = entity_header['ADtoMV'] offset = entity_header['MVOffset'] group_id = 0 - sig_channels.append((name, _id, sampling_rate, dtype, units, + sig_channels.append((name, _id, sampling_rate, dtype, units, gain, offset, group_id)) self._sig_lengths.append(entity_header['NPointsWave']) # sig t_start is the first timestamp if datablock @@ -146,29 +146,29 @@ def _get_signal_t_start(self, block_index, seg_index, channel_indexes): assert len(channel_indexes) == 1, 'only one channel by one channel' return self._sig_t_starts[channel_indexes[0]] - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): + def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): assert len(channel_indexes) == 1, 'only one channel by one channel' channel_index = channel_indexes[0] entity_index = int(self.header['signal_channels'][channel_index]['id']) entity_header = self._entity_headers[entity_index] n = entity_header['n'] nb_sample = entity_header['NPointsWave'] - #offset = entity_header['offset'] - #timestamps = self._memmap[offset:offset+n*4].view('int32') - #offset2 = entity_header['offset'] + n*4 - #fragment_starts = self._memmap[offset2:offset2+n*4].view('int32') + # offset = entity_header['offset'] + # timestamps = self._memmap[offset:offset+n*4].view('int32') + # offset2 = 
entity_header['offset'] + n*4 + # fragment_starts = self._memmap[offset2:offset2+n*4].view('int32') offset3 = entity_header['offset'] + n * 4 + n * 4 raw_signal = self._memmap[offset3:offset3 + nb_sample * 2].view('int16') raw_signal = raw_signal[slice(i_start, i_stop), None] # 2D for compliance return raw_signal - def _spike_count(self, block_index, seg_index, unit_index): + def _spike_count(self, block_index, seg_index, unit_index): entity_index = int(self.header['unit_channels'][unit_index]['id']) entity_header = self._entity_headers[entity_index] nb_spike = entity_header['n'] return nb_spike - def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): + def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): entity_index = int(self.header['unit_channels'][unit_index]['id']) entity_header = self._entity_headers[entity_index] n = entity_header['n'] @@ -201,6 +201,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, offset = entity_header['offset'] + n * 2 waveforms = self._memmap[offset:offset + n * 2 * width].view('int16') waveforms = waveforms.reshape(n, 1, width) + waveforms = np.moveaxis(waveforms, 2, 0) return waveforms @@ -210,7 +211,7 @@ def _event_count(self, block_index, seg_index, event_channel_index): nb_event = entity_header['n'] return nb_event - def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): + def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): entity_index = int(self.header['event_channels'][event_channel_index]['id']) entity_header = self._entity_headers[entity_index] diff --git a/neo/rawio/plexonrawio.py b/neo/rawio/plexonrawio.py index abfff68cf..b2d025bfc 100644 --- a/neo/rawio/plexonrawio.py +++ b/neo/rawio/plexonrawio.py @@ -46,7 +46,7 @@ def _source_name(self): def _parse_header(self): - #global header + # global header with open(self.filename, 'rb') as fid: offset0 = 0 global_header = read_as_dict(fid, GlobalHeader, offset=offset0) @@ -98,15 +98,15 @@ def _parse_header(self): block_pos[bl_type][chan_id].append(pos) pos += length - self._last_timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \ - 2 ** 32 + bl_header['TimeStamp'] + self._last_timestamps = (bl_header['UpperByteOf5ByteTimestamp'] * \ + 2 ** 32 + bl_header['TimeStamp']) - #... and finalize them in self._data_blocks + # ... and finalize them in self._data_blocks # for a faster acces depending on type (1, 4, 5) self._data_blocks = {} dt_base = [('pos', 'int64'), ('timestamp', 'int64'), ('size', 'int64')] dtype_by_bltype = { - #Spikes and waveforms + # Spikes and waveforms 1: np.dtype(dt_base + [('unit_id', 'uint16'), ('n1', 'uint16'), ('n2', 'uint16'), ]), # Events 4: np.dtype(dt_base + [('label', 'uint16'), ]), @@ -119,8 +119,8 @@ def _parse_header(self): bl_header = np.array(block_headers[bl_type][chan_id], dtype=DataBlockHeader) bl_pos = np.array(block_pos[bl_type][chan_id], dtype='int64') - timestamps = bl_header['UpperByteOf5ByteTimestamp'] * \ - 2 ** 32 + bl_header['TimeStamp'] + timestamps = (bl_header['UpperByteOf5ByteTimestamp'] * \ + 2 ** 32 + bl_header['TimeStamp']) n1 = bl_header['NumberOfWaveforms'] n2 = bl_header['NumberOfWordsInWaveform'] @@ -256,7 +256,7 @@ def _get_signal_size(self, block_index, seg_index, channel_indexes): def _get_signal_t_start(self, block_index, seg_index, channel_indexes): return 0. 
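The chunk and waveform readers above slice a flat byte memmap and reinterpret the bytes in place rather than copying; a self-contained sketch of that pattern (the array stands in for self._memmap, and the layout numbers are hypothetical):

    import numpy as np

    n, nb_sample = 4, 10
    # n int32 timestamps, n int32 fragment starts, then int16 samples
    raw = np.zeros(n * 4 + n * 4 + nb_sample * 2, dtype='uint8')
    offset3 = n * 4 + n * 4
    samples = raw[offset3:offset3 + nb_sample * 2].view('int16')
    assert samples.shape == (nb_sample,)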
- def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): + def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): if i_start is None: i_start = 0 if i_stop is None: @@ -276,7 +276,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, chan bl0 = np.searchsorted(data_blocks['cumsum'], i_start, side='left') bl1 = np.searchsorted(data_blocks['cumsum'], i_stop, side='left') ind = 0 - for bl in range(bl0, bl1): + for bl in range(bl0, bl1): ind0 = data_blocks[bl]['pos'] ind1 = data_blocks[bl]['size'] + ind0 data = self._memmap[ind0:ind1].view('int16') @@ -311,13 +311,13 @@ def _get_internal_mask(self, data_block, t_start, t_stop): return keep - def _spike_count(self, block_index, seg_index, unit_index): + def _spike_count(self, block_index, seg_index, unit_index): chan_id, unit_id = self.internal_unit_ids[unit_index] data_block = self._data_blocks[1][chan_id] nb_spike = np.sum(data_block['unit_id'] == unit_id) return nb_spike - def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): + def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): chan_id, unit_id = self.internal_unit_ids[unit_index] data_block = self._data_blocks[1][chan_id] @@ -351,6 +351,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, ind1 = db['size'] + ind0 data = self._memmap[ind0:ind1].view('int16').reshape(n1, n2) waveforms[i, :, :] = data + waveforms = np.moveaxis(waveforms, 2, 0) return waveforms @@ -359,7 +360,7 @@ def _event_count(self, block_index, seg_index, event_channel_index): nb_event = self._data_blocks[4][chan_id].size return nb_event - def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): + def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): chan_id = int(self.header['event_channels'][event_channel_index]['id']) data_block = self._data_blocks[4][chan_id] keep = self._get_internal_mask(data_block, t_start, t_stop) diff --git a/neo/rawio/spike2rawio.py b/neo/rawio/spike2rawio.py index 08ef148ab..eba1c7829 100644 --- a/neo/rawio/spike2rawio.py +++ b/neo/rawio/spike2rawio.py @@ -8,8 +8,7 @@ http://www.neuro.ki.se/broberger/ and sonpy come from : - - SON Library 2.0 for MATLAB, written by Malcolm Lidierth at - King's College London. + - SON Library 2.0 for MATLAB, written by Malcolm Lidierth at King's College London. See http://www.kcl.ac.uk/depsta/biomedical/cfnr/lidierth.html This IO support old (v7) of spike2 @@ -85,8 +84,8 @@ def _parse_header(self): else: fid.seek(chan_info['firstblock']) block_info = read_as_dict(fid, blockHeaderDesciption) - chan_info['t_start'] = block_info['start_time'] * \ - info['us_per_time'] * info['dtime_base'] + chan_info['t_start'] = (block_info['start_time'] * \ + info['us_per_time'] * info['dtime_base']) self._channel_infos.append(chan_info) @@ -100,7 +99,7 @@ def _parse_header(self): ind = chan_info['firstblock'] for b in range(chan_info['blocks']): block_info = self._memmap[ind:ind + 20].view(blockHeaderDesciption)[0] - data_blocks.append((ind, block_info['items'], 0, + data_blocks.append((ind, block_info['items'], 0, block_info['start_time'], block_info['end_time'])) ind = block_info['succ_block'] @@ -111,7 +110,7 @@ def _parse_header(self): self._all_data_blocks[c] = data_blocks self._by_seg_data_blocks[c] = [] - + # For all signal channel detect gaps between data block (pause in rec) so new Segment. 
# then check that all channel have the same gaps. # this part is tricky because we need to check that all channel have same pause. @@ -150,30 +149,30 @@ def _parse_header(self): assert np.all(all_nb_seg[0]==all_nb_seg), \ 'Signal channel have differents pause so diffrents nb_segment' nb_segment = int(all_nb_seg[0]) - + for chan_id, gaps_block_ind in all_gaps_block_ind.items(): data_blocks = self._all_data_blocks[chan_id] self._sig_t_starts[chan_id] = [] self._sig_t_stops[chan_id] = [] - + for seg_ind in range(nb_segment): if seg_ind==0: fisrt_bl = 0 else: fisrt_bl = gaps_block_ind[seg_ind-1] + 1 self._sig_t_starts[chan_id].append(data_blocks[fisrt_bl]['start_time']) - + if seg_ind 0: # signal channel can different sampling_rate/dtype/t_start/signal_length... # grouping them is difficults, so each channe = one group - + sig_channels['group_id'] = np.arange(sig_channels.size) self._sig_dtypes = {s['group_id']: np.dtype(s['dtype']) for s in sig_channels} - + # fille into header dict self.header = {} self.header['nb_block'] = 1 @@ -323,7 +322,7 @@ def _get_signal_t_start(self, block_index, seg_index, channel_indexes): chan_id = self.header['signal_channels'][channel_indexes[0]]['id'] return self._sig_t_starts[chan_id][seg_index] * self._time_factor - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): + def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): if i_start is None: i_start = 0 if i_stop is None: @@ -349,7 +348,7 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, chan bl0 = np.searchsorted(data_blocks['cumsum'], i_start, side='left') bl1 = np.searchsorted(data_blocks['cumsum'], i_stop, side='left') ind = 0 - for bl in range(bl0, bl1): + for bl in range(bl0, bl1): ind0 = data_blocks[bl]['pos'] ind1 = data_blocks[bl]['size'] * dt.itemsize + ind0 data = self._memmap[ind0:ind1].view(dt) @@ -366,7 +365,8 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, chan raw_signals[ind:data.size + ind, c] = data ind += data.size return raw_signals - + + def _count_in_time_slice(self, seg_index, chan_id, lim0, lim1, marker_filter=None): # count event or spike in time slice data_blocks = self._all_data_blocks[chan_id] @@ -386,8 +386,8 @@ def _count_in_time_slice(self, seg_index, chan_id, lim0, lim1, marker_filter=Non if ts[-1] > lim1: break return nb - - def _get_internal_timestamp_(self, seg_index, chan_id, + + def _get_internal_timestamp_(self, seg_index, chan_id, t_start, t_stop, other_field=None, marker_filter=None): chan_info = self._channel_infos[chan_id] # data_blocks = self._by_seg_data_blocks[chan_id][seg_index] @@ -443,13 +443,13 @@ def _spike_count(self, block_index, seg_index, unit_index): if self.ced_units: marker_filter = unit_id else: - marker_filter = None + marker_filter = None lim0 = self._seg_t_starts[seg_index] lim1 = self._seg_t_stops[seg_index] return self._count_in_time_slice(seg_index, chan_id, lim0, lim1, marker_filter=marker_filter) - def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): + def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): unit_header = self.header['unit_channels'][unit_index] chan_id, unit_id = self.internal_unit_ids[unit_index] @@ -477,10 +477,11 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, else: marker_filter = None - timestamps, waveforms = self._get_internal_timestamp_(seg_index, chan_id, + timestamps, waveforms = 
self._get_internal_timestamp_(seg_index, chan_id, t_start, t_stop, other_field='waveform', marker_filter=marker_filter) waveforms = waveforms.reshape(timestamps.size, 1, -1) + waveforms = np.moveaxis(waveforms, 2, 0) return waveforms @@ -491,7 +492,7 @@ def _event_count(self, block_index, seg_index, event_channel_index): lim1 = self._seg_t_stops[seg_index] return self._count_in_time_slice(seg_index, chan_id, lim0, lim1, marker_filter=None) - def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): + def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): event_header = self.header['event_channels'][event_channel_index] chan_id = int(event_header['id']) # because set to string in header chan_info = self._channel_infos[chan_id] @@ -533,8 +534,8 @@ def read_as_dict(fid, dtype): if dt[k].kind == 'S': v = v.decode('iso-8859-1') if len(v) > 0: - l = ord(v[0]) - v = v[1:l + 1] + m = ord(v[0]) + v = v[1:m + 1] info[k] = v return info diff --git a/neo/rawio/tdtrawio.py b/neo/rawio/tdtrawio.py index 34e2b8834..150796f91 100644 --- a/neo/rawio/tdtrawio.py +++ b/neo/rawio/tdtrawio.py @@ -41,7 +41,8 @@ def __init__(self, dirname='', sortname=''): """ 'sortname' is used to specify the external sortcode generated by offline spike sorting. if sortname=='PLX', there should be a ./sort/PLX/*.SortResult file in the tdt block, - which stores the sortcode for every spike; defaults to '', which uses the original online sort + which stores the sortcode for every spike; defaults to '', + which uses the original online sort """ BaseRawIO.__init__(self) if dirname.endswith('/'): @@ -111,7 +112,8 @@ def _parse_header(self): self._seg_t_stops.append(np.nan) print('segment stop time not found') - # If there exists an external sortcode in ./sort/[sortname]/*.SortResult (generated after offline sorting) + # If there exists an external sortcode in ./sort/[sortname]/*.SortResult + # (generated after offline sorting) if self.sortname is not '': try: for file in os.listdir(os.path.join(path, 'sort', sortname)): @@ -119,7 +121,7 @@ def _parse_header(self): sortresult_filename = os.path.join(path, 'sort', sortname, file) # get new sortcode newsortcode = np.fromfile(sortresult_filename, 'int8')[ - 1024:] # first 1024 bytes are header + 1024:] # first 1024 bytes are header # update the sort code with the info from this file tsq['sortcode'][1:-1] = newsortcode # print('sortcode updated') @@ -164,8 +166,8 @@ def _parse_header(self): for seg_index, segment_name in enumerate(segment_names): # get data index tsq = self._tsq[seg_index] - mask = (tsq['evtype'] == EVTYPE_STREAM) &\ - (tsq['evname'] == info['StoreName']) &\ + mask = (tsq['evtype'] == EVTYPE_STREAM) & \ + (tsq['evname'] == info['StoreName']) & \ (tsq['channel'] == chan_id) data_index = tsq[mask].copy() self._sigs_index[seg_index][chan_index] = data_index @@ -199,8 +201,9 @@ def _parse_header(self): # data buffer test if SEV file exists otherwise TEV path = os.path.join(self.dirname, segment_name) - sev_filename = os.path.join(path, tankname + '_' + segment_name + '_' + - info['StoreName'].decode('ascii') + '_ch' + str(chan_id) + '.sev') + fname = (tankname + '_' + segment_name + '_' + + info['StoreName'].decode('ascii') + '_ch' + str(chan_id) + '.sev') + sev_filename = os.path.join(path, fname) if os.path.exists(sev_filename): data = np.memmap(sev_filename, mode='r', offset=0, dtype='uint8') else: @@ -224,12 +227,13 @@ def _parse_header(self): unit_channels = [] keep = 
info_channel_groups['TankEvType'] == EVTYPE_SNIP tsq = np.hstack(self._tsq) - # If there is no chance the differet TSQ files will have different units, then we can do tsq = self._tsq[0] + # If there is no chance the different TSQ files will have different units, + # then we can do tsq = self._tsq[0] for info in info_channel_groups[keep]: for c in range(info['NumChan']): chan_id = c + 1 - mask = (tsq['evtype'] == EVTYPE_SNIP) &\ - (tsq['evname'] == info['StoreName']) &\ + mask = (tsq['evtype'] == EVTYPE_SNIP) & \ + (tsq['evname'] == info['StoreName']) & \ (tsq['channel'] == chan_id) unit_ids = np.unique(tsq[mask]['sortcode']) for unit_id in unit_ids: @@ -293,7 +297,7 @@ def _get_signal_t_start(self, block_index, seg_index, channel_indexes): group_id = self.header['signal_channels'][channel_indexes[0]]['group_id'] return self._sigs_t_start[seg_index][group_id] - self._global_t_start - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): + def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, channel_indexes): # check of channel_indexes is same group_id is done outside (BaseRawIO) # so first is identique to others group_id = self.header['signal_channels'][channel_indexes[0]]['group_id'] @@ -354,15 +358,15 @@ def _get_mask(self, tsq, seg_index, evtype, evname, chan_id, unit_id, t_start, t return mask - def _spike_count(self, block_index, seg_index, unit_index): + def _spike_count(self, block_index, seg_index, unit_index): store_name, chan_id, unit_id = self.internal_unit_ids[unit_index] tsq = self._tsq[seg_index] - mask = self. _get_mask(tsq, seg_index, EVTYPE_SNIP, store_name, - chan_id, unit_id, None, None) + mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name, + chan_id, unit_id, None, None) nb_spike = np.sum(mask) return nb_spike - def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): + def _get_spike_timestamps(self, block_index, seg_index, unit_index, t_start, t_stop): store_name, chan_id, unit_id = self.internal_unit_ids[unit_index] tsq = self._tsq[seg_index] mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name, @@ -379,8 +383,8 @@ def _rescale_spike_timestamp(self, spike_timestamps, dtype): def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, t_stop): store_name, chan_id, unit_id = self.internal_unit_ids[unit_index] tsq = self._tsq[seg_index] - mask = self. _get_mask(tsq, seg_index, EVTYPE_SNIP, store_name, - chan_id, unit_id, t_start, t_stop) + mask = self._get_mask(tsq, seg_index, EVTYPE_SNIP, store_name, + chan_id, unit_id, t_start, t_stop) nb_spike = np.sum(mask) data = self._tev_datas[seg_index] @@ -393,6 +397,7 @@ def _get_spike_raw_waveforms(self, block_index, seg_index, unit_index, t_start, ind0 = e['offset'] ind1 = ind0 + nb_sample * dt.itemsize waveforms[i, 0, :] = data[ind0:ind1].view(dt) + waveforms = np.moveaxis(waveforms, 2, 0) return waveforms @@ -401,16 +406,16 @@ def _event_count(self, block_index, seg_index, event_channel_index): store_name = h['name'].encode('ascii') tsq = self._tsq[seg_index] chan_id = 0 - mask = self.
_get_mask(tsq, seg_index, EVTYPE_STRON, store_name, chan_id, None, None, None) + mask = self._get_mask(tsq, seg_index, EVTYPE_STRON, store_name, chan_id, None, None, None) nb_event = np.sum(mask) return nb_event - def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): + def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop): h = self.header['event_channels'][event_channel_index] store_name = h['name'].encode('ascii') tsq = self._tsq[seg_index] chan_id = 0 - mask = self. _get_mask(tsq, seg_index, EVTYPE_STRON, store_name, chan_id, None, None, None) + mask = self._get_mask(tsq, seg_index, EVTYPE_STRON, store_name, chan_id, None, None, None) timestamps = tsq[mask]['timestamp'] timestamps -= self._global_t_start @@ -474,32 +479,31 @@ def read_tbk(tbk_filename): tsq_dtype = [ - ('size', 'int32'), # bytes 0-4 - ('evtype', 'int32'), # bytes 5-8 - ('evname', 'S4'), # bytes 9-12 - ('channel', 'uint16'), # bytes 13-14 - ('sortcode', 'uint16'), # bytes 15-16 - ('timestamp', 'float64'), # bytes 17-24 - ('offset', 'int64'), # bytes 25-32 - ('dataformat', 'int32'), # bytes 33-36 - ('frequency', 'float32'), # bytes 37-40 + ('size', 'int32'), # bytes 0-4 + ('evtype', 'int32'), # bytes 5-8 + ('evname', 'S4'), # bytes 9-12 + ('channel', 'uint16'), # bytes 13-14 + ('sortcode', 'uint16'), # bytes 15-16 + ('timestamp', 'float64'), # bytes 17-24 + ('offset', 'int64'), # bytes 25-32 + ('dataformat', 'int32'), # bytes 33-36 + ('frequency', 'float32'), # bytes 37-40 ] -EVTYPE_UNKNOWN = int('00000000', 16) # 0 -EVTYPE_STRON = int('00000101', 16) # 257 -EVTYPE_STROFF = int('00000102', 16) # 258 -EVTYPE_SCALAR = int('00000201', 16) # 513 -EVTYPE_STREAM = int('00008101', 16) # 33025 -EVTYPE_SNIP = int('00008201', 16) # 33281 -EVTYPE_MARK = int('00008801', 16) # 34817 -EVTYPE_HASDATA = int('00008000', 16) # 32768 -EVTYPE_UCF = int('00000010', 16) # 16 -EVTYPE_PHANTOM = int('00000020', 16) # 32 -EVTYPE_MASK = int('0000FF0F', 16) # 65295 -EVTYPE_INVALID_MASK = int('FFFF0000', 16) # 4294901760 -EVMARK_STARTBLOCK = int('0001', 16) # 1 -EVMARK_STOPBLOCK = int('0002', 16) # 2 - +EVTYPE_UNKNOWN = int('00000000', 16) # 0 +EVTYPE_STRON = int('00000101', 16) # 257 +EVTYPE_STROFF = int('00000102', 16) # 258 +EVTYPE_SCALAR = int('00000201', 16) # 513 +EVTYPE_STREAM = int('00008101', 16) # 33025 +EVTYPE_SNIP = int('00008201', 16) # 33281 +EVTYPE_MARK = int('00008801', 16) # 34817 +EVTYPE_HASDATA = int('00008000', 16) # 32768 +EVTYPE_UCF = int('00000010', 16) # 16 +EVTYPE_PHANTOM = int('00000020', 16) # 32 +EVTYPE_MASK = int('0000FF0F', 16) # 65295 +EVTYPE_INVALID_MASK = int('FFFF0000', 16) # 4294901760 +EVMARK_STARTBLOCK = int('0001', 16) # 1 +EVMARK_STOPBLOCK = int('0002', 16) # 2 data_formats = { 0: 'float32', @@ -520,7 +524,7 @@ def is_tdtblock(blockpath): file_ext = set(file_ext) tdt_ext = {'.tbk', '.tdx', '.tev', '.tsq'} - if file_ext >= tdt_ext: # if containing all the necessary files + if file_ext >= tdt_ext: # if containing all the necessary files return True else: return False diff --git a/neo/rawio/tests/rawio_compliance.py b/neo/rawio/tests/rawio_compliance.py index 74e204e50..c7bb36841 100644 --- a/neo/rawio/tests/rawio_compliance.py +++ b/neo/rawio/tests/rawio_compliance.py @@ -10,16 +10,14 @@ """ import time -if not hasattr(time, 'perf_counter'): - time.perf_counter = time.time import logging - import numpy as np - - from neo.rawio.baserawio import (_signal_channel_dtype, _unit_channel_dtype, _event_channel_dtype, 
_common_sig_characteristics) +if not hasattr(time, 'perf_counter'): + time.perf_counter = time.time + def print_class(reader): return reader.__class__.__name__ @@ -115,8 +113,10 @@ def iter_over_sig_chunks(reader, channel_indexes, chunksize=1024): for i in range(nb): i_start = i * chunksize i_stop = min((i + 1) * chunksize, sig_size) - raw_chunk = reader.get_analogsignal_chunk(block_index=block_index, seg_index=seg_index, - i_start=i_start, i_stop=i_stop, channel_indexes=channel_indexes) + raw_chunk = reader.get_analogsignal_chunk(block_index=block_index, + seg_index=seg_index, + i_start=i_start, i_stop=i_stop, + channel_indexes=channel_indexes) yield raw_chunk @@ -140,7 +140,7 @@ def read_analogsignals(reader): for channel_indexes in channel_indexes_list: for raw_chunk in iter_over_sig_chunks(reader, channel_indexes, chunksize=1024): assert raw_chunk.ndim == 2 - #~ pass + # ~ pass for channel_indexes in channel_indexes_list: sr = reader.get_signal_sampling_rate(channel_indexes=channel_indexes) @@ -173,19 +173,24 @@ def read_analogsignals(reader): channel_ids2 = signal_ids[::2] raw_chunk0 = reader.get_analogsignal_chunk(block_index=block_index, seg_index=seg_index, - i_start=i_start, i_stop=i_stop, channel_indexes=channel_indexes2) + i_start=i_start, i_stop=i_stop, + channel_indexes=channel_indexes2) assert raw_chunk0.ndim == 2 assert raw_chunk0.shape[0] == i_stop assert raw_chunk0.shape[1] == len(channel_indexes2) if unique_chan_name: - raw_chunk1 = reader.get_analogsignal_chunk(block_index=block_index, seg_index=seg_index, - i_start=i_start, i_stop=i_stop, channel_names=channel_names2) + raw_chunk1 = reader.get_analogsignal_chunk(block_index=block_index, + seg_index=seg_index, + i_start=i_start, i_stop=i_stop, + channel_names=channel_names2) np.testing.assert_array_equal(raw_chunk0, raw_chunk1) if unique_chan_id: - raw_chunk2 = reader.get_analogsignal_chunk(block_index=block_index, seg_index=seg_index, - i_start=i_start, i_stop=i_stop, channel_ids=channel_ids2) + raw_chunk2 = reader.get_analogsignal_chunk(block_index=block_index, + seg_index=seg_index, + i_start=i_start, i_stop=i_stop, + channel_ids=channel_ids2) np.testing.assert_array_equal(raw_chunk0, raw_chunk2) # convert to float32/float64 @@ -231,8 +236,10 @@ def benchmark_speed_read_signals(reader): nb_samples += raw_chunk.shape[0] t1 = time.perf_counter() speed = (nb_samples * nb_sig) / (t1 - t0) / 1e6 - logging.info('{} read ({}signals x {}samples) in {:0.3f} s so speed {:0.3f} MSPS from {}'.format(print_class(reader), - nb_sig, nb_samples, t1 - t0, speed, reader.source_name())) + logging.info( + '{} read ({}signals x {}samples) in {:0.3f} s so speed {:0.3f} MSPS from {}'.format( + print_class(reader), + nb_sig, nb_samples, t1 - t0, speed, reader.source_name())) def read_spike_times(reader): @@ -252,8 +259,10 @@ def read_spike_times(reader): if nb_spike == 0: continue - spike_timestamp = reader.get_spike_timestamps(block_index=block_index, seg_index=seg_index, - unit_index=unit_index, t_start=None, t_stop=None) + spike_timestamp = reader.get_spike_timestamps(block_index=block_index, + seg_index=seg_index, + unit_index=unit_index, t_start=None, + t_stop=None) assert spike_timestamp.shape[0] == nb_spike, 'nb_spike {} != {}'.format( spike_timestamp.shape[0], nb_spike) @@ -265,8 +274,10 @@ def read_spike_times(reader): t_start = spike_times[1] - 0.001 t_stop = spike_times[1] + 0.001 - spike_timestamp2 = reader.get_spike_timestamps(block_index=block_index, seg_index=seg_index, - unit_index=unit_index, t_start=t_start, t_stop=t_stop) 
+        spike_timestamp2 = reader.get_spike_timestamps(block_index=block_index,
+                                                       seg_index=seg_index,
+                                                       unit_index=unit_index,
+                                                       t_start=t_start, t_stop=t_stop)
             assert spike_timestamp2.shape[0] == 1

             spike_times2 = reader.rescale_spike_timestamp(spike_timestamp2, 'float64')
@@ -290,11 +301,12 @@ def read_spike_waveforms(reader):
                 continue

             raw_waveforms = reader.get_spike_raw_waveforms(block_index=block_index,
-                                                           seg_index=seg_index, unit_index=unit_index,
+                                                           seg_index=seg_index,
+                                                           unit_index=unit_index,
                                                            t_start=None, t_stop=None)
             if raw_waveforms is None:
                 continue
-            assert raw_waveforms.shape[0] == nb_spike
+            assert raw_waveforms.shape[1] == nb_spike
             assert raw_waveforms.ndim == 3

             for dt in ('float32', 'float64'):
@@ -320,8 +332,9 @@ def read_events(reader):
             if nb_event == 0:
                 continue

-            ev_timestamps, ev_durations, ev_labels = reader.get_event_timestamps(block_index=block_index, seg_index=seg_index,
-                                                                                 event_channel_index=ev_chan)
+            ev_timestamps, ev_durations, ev_labels = reader.get_event_timestamps(
+                block_index=block_index, seg_index=seg_index,
+                event_channel_index=ev_chan)
             assert ev_timestamps.shape[0] == nb_event, 'Wrong shape {}, {}'.format(
                 ev_timestamps.shape[0], nb_event)
             if ev_durations is not None:
diff --git a/neo/rawio/tests/test_blackrockrawio.py b/neo/rawio/tests/test_blackrockrawio.py
index 648089d8c..1061fa4b5 100644
--- a/neo/rawio/tests/test_blackrockrawio.py
+++ b/neo/rawio/tests/test_blackrockrawio.py
@@ -41,7 +41,7 @@ class TestBlackrockRawIO(BaseTestRawIO, unittest.TestCase, ):
     @unittest.skipUnless(HAVE_SCIPY, "requires scipy")
     def test_compare_blackrockio_with_matlabloader(self):
         """
-        This test compares the output of ReachGraspIO.read_block() with the
+        This test compares the output of BlackrockIO.read_block() with the
         output generated by a Matlab implementation of a Blackrock file
         reader provided by the company. The output for comparison is provided
         in a .mat file created by the script create_data_matlab_blackrock.m.
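The axis bookkeeping this comparison relies on, sketched with made-up arrays (wf_ml and get_spike_raw_waveforms follow the hunks below; the concrete shapes are illustrative only, not taken from the test data):

    import numpy as np

    # Hypothetical reference waveforms as the Matlab loader stores them:
    # one row per spike.
    wf_ml = np.arange(12.).reshape(4, 3)    # (spike, time)
    wf_ml = wf_ml.T                         # (time, spike), the new convention

    # get_spike_raw_waveforms() returns (time, spike, channel_index); with a
    # single channel, dropping axis 2 leaves a (time, spike) array that can be
    # compared elementwise against the transposed Matlab reference.
    io_waveforms = wf_ml[:, :, np.newaxis]  # stand-in for the reader's output
    np.testing.assert_array_equal(io_waveforms[:, :, 0], wf_ml)
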
@@ -59,7 +59,7 @@ def test_compare_blackrockio_with_matlabloader(self): ts_ml = ml['ts'] # spike time stamps elec_ml = ml['el'] # spike electrodes unit_ml = ml['un'] # spike unit IDs - wf_ml = ml['wf'] # waveform unit 1 channel 1 + wf_ml = ml['wf'].T # waveform unit 1 channel 1 dimensions = (time, spike) mts_ml = ml['mts'] # marker time stamps mid_ml = ml['mid'] # marker IDs @@ -92,7 +92,7 @@ def test_compare_blackrockio_with_matlabloader(self): # Check waveforms of channel 1, unit 0 if channel_id == 1 and unit_id == 0: io_waveforms = reader.get_spike_raw_waveforms(unit_index=unit_index) - io_waveforms = io_waveforms[:, 0, :] # remove dim 1 + io_waveforms = io_waveforms[:, :, 0] # remove dim 2 (channel_id) assert_equal(io_waveforms, wf_ml) # Check if digital input port events are equal diff --git a/neo/test/coretest/test_spiketrain.py b/neo/test/coretest/test_spiketrain.py index ed5cec3b4..ba26391d9 100644 --- a/neo/test/coretest/test_spiketrain.py +++ b/neo/test/coretest/test_spiketrain.py @@ -41,9 +41,9 @@ def setUp(self): def test__get_fake_values(self): self.annotations['seed'] = 0 waveforms = get_fake_value('waveforms', pq.Quantity, seed=3, dim=3) - shape = waveforms.shape[0] + shape = waveforms.shape[1] times = get_fake_value('times', pq.Quantity, seed=0, dim=1, - shape=waveforms.shape[0]) + shape=waveforms.shape[1]) t_start = get_fake_value('t_start', pq.Quantity, seed=1, dim=0) t_stop = get_fake_value('t_stop', pq.Quantity, seed=2, dim=0) left_sweep = get_fake_value('left_sweep', pq.Quantity, seed=4, dim=0) @@ -815,6 +815,7 @@ def test_tstop_units_conversion(self): class TestSorting(unittest.TestCase): def test_sort(self): waveforms = np.array([[[0., 1.]], [[2., 3.]], [[4., 5.]]]) * pq.mV + waveforms = np.moveaxis(waveforms, 2, 0) train = SpikeTrain([3, 4, 5] * pq.s, waveforms=waveforms, name='n', t_stop=10.0) assert_neo_object_is_compliant(train) @@ -831,7 +832,7 @@ def test_sort(self): train.sort() assert_neo_object_is_compliant(train) assert_arrays_equal(train, [3, 4, 5] * pq.s) - assert_arrays_equal(train.waveforms, waveforms[[0, 2, 1]]) + assert_arrays_equal(train.waveforms, waveforms[:, [0, 2, 1], :]) self.assertEqual(train.name, 'n') self.assertEqual(train.t_start, 0.0 * pq.s) self.assertEqual(train.t_stop, 10.0 * pq.s) @@ -845,6 +846,7 @@ def setUp(self): [2.1, 3.1]], [[4., 5.], [4.1, 5.1]]]) * pq.mV + self.waveforms1 = np.moveaxis(self.waveforms1, 2, 0) self.data1 = np.array([3, 4, 5]) self.data1quant = self.data1 * pq.s self.train1 = SpikeTrain(self.data1quant, waveforms=self.waveforms1, @@ -859,6 +861,7 @@ def test_slice(self): assert_arrays_equal(self.train1[1:2], result) targwaveforms = np.array([[[2., 3.], [2.1, 3.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) # but keep everything else pristine assert_neo_object_is_compliant(result) @@ -871,7 +874,7 @@ def test_slice(self): self.assertEqual(self.train1.t_stop, result.t_stop) # except we update the waveforms - assert_arrays_equal(self.train1.waveforms[1:2], result.waveforms) + assert_arrays_equal(self.train1.waveforms[:, 1:2, :], result.waveforms) assert_arrays_equal(targwaveforms, result.waveforms) def test_slice_to_end(self): @@ -882,6 +885,7 @@ def test_slice_to_end(self): [2.1, 3.1]], [[4., 5.], [4.1, 5.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) # but keep everything else pristine assert_neo_object_is_compliant(result) @@ -894,7 +898,7 @@ def test_slice_to_end(self): self.assertEqual(self.train1.t_stop, result.t_stop) # except we update the waveforms - 
assert_arrays_equal(self.train1.waveforms[1:], result.waveforms) + assert_arrays_equal(self.train1.waveforms[:, 1:, :], result.waveforms) assert_arrays_equal(targwaveforms, result.waveforms) def test_slice_from_beginning(self): @@ -905,6 +909,7 @@ def test_slice_from_beginning(self): [0.1, 1.1]], [[2., 3.], [2.1, 3.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) # but keep everything else pristine assert_neo_object_is_compliant(result) @@ -917,7 +922,7 @@ def test_slice_from_beginning(self): self.assertEqual(self.train1.t_stop, result.t_stop) # except we update the waveforms - assert_arrays_equal(self.train1.waveforms[:2], result.waveforms) + assert_arrays_equal(self.train1.waveforms[:, :2, :], result.waveforms) assert_arrays_equal(targwaveforms, result.waveforms) def test_slice_negative_idxs(self): @@ -928,6 +933,7 @@ def test_slice_negative_idxs(self): [0.1, 1.1]], [[2., 3.], [2.1, 3.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) # but keep everything else pristine assert_neo_object_is_compliant(result) @@ -940,7 +946,7 @@ def test_slice_negative_idxs(self): self.assertEqual(self.train1.t_stop, result.t_stop) # except we update the waveforms - assert_arrays_equal(self.train1.waveforms[:-1], result.waveforms) + assert_arrays_equal(self.train1.waveforms[:, :-1, :], result.waveforms) assert_arrays_equal(targwaveforms, result.waveforms) @@ -958,6 +964,7 @@ def setUp(self): [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV + self.waveforms1 = np.moveaxis(self.waveforms1, 2, 0) self.data1 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7]) self.data1quant = self.data1 * pq.ms self.train1 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms, @@ -981,6 +988,7 @@ def test_time_slice_typical(self): [4.1, 5.1]], [[6., 7.], [6.1, 7.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) assert_arrays_equal(targwaveforms, result.waveforms) # but keep everything else pristine @@ -1006,6 +1014,7 @@ def test_time_slice_differnt_units(self): [4.1, 5.1]], [[6., 7.], [6.1, 7.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) assert_arrays_equal(targwaveforms, result.waveforms) # but keep everything else pristine @@ -1067,7 +1076,7 @@ def test_time_slice_empty(self): t_stop = 70.0 * pq.ms result = train.time_slice(t_start, t_stop) assert_arrays_equal(train, result) - assert_arrays_equal(waveforms[:-1], result.waveforms) + assert_arrays_equal(waveforms[:, :-1, :], result.waveforms) # but keep everything else pristine assert_neo_object_is_compliant(result) @@ -1092,6 +1101,7 @@ def test_time_slice_none_stop(self): [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) assert_arrays_equal(targwaveforms, result.waveforms) # but keep everything else pristine @@ -1113,6 +1123,7 @@ def test_time_slice_none_start(self): [0.1, 1.1]], [[2., 3.], [2.1, 3.1]]]) * pq.mV + targwaveforms = np.moveaxis(targwaveforms, 2, 0) assert_arrays_equal(targwaveforms, result.waveforms) # but keep everything else pristine @@ -1159,6 +1170,7 @@ def setUp(self): [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV + self.waveforms1 = np.moveaxis(self.waveforms1, 2, 0) self.data1 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7]) self.data1quant = self.data1 * pq.ms self.train1 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms, @@ -1176,6 +1188,7 @@ def setUp(self): [8.1, 9.1]], [[10., 11.], [10.1, 11.1]]]) * pq.mV + self.waveforms2 = np.moveaxis(self.waveforms2, 2, 0) self.data2 = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7]) self.data2quant = self.data1 * pq.ms 
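The np.moveaxis(waveforms, 2, 0) calls used throughout these fixtures rotate test data written in the old (spike, channel_index, time) layout into the new (time, spike, channel_index) layout; a minimal sketch of that mapping (shapes are illustrative only):

    import numpy as np
    import quantities as pq

    # Old layout: 3 spikes x 1 channel x 2 samples per waveform.
    wf_old = np.array([[[0., 1.]], [[2., 3.]], [[4., 5.]]]) * pq.mV

    # Move the time axis (axis 2) to the front; spike and channel keep order.
    wf_new = np.moveaxis(wf_old, 2, 0)
    assert wf_new.shape == (2, 3, 1)        # (time, spike, channel_index)

    # Under the new layout, individual spikes are selected on axis 1:
    second_spike = wf_new[:, 1, :]          # (time, channel_index) of spike 1
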
self.train2 = SpikeTrain(self.data1quant, t_stop=10.0 * pq.ms, @@ -1253,18 +1266,13 @@ def test_incompatible_t_start(self): class TestDuplicateWithNewData(unittest.TestCase): def setUp(self): - self.waveforms = np.array([[[0., 1.], - [0.1, 1.1]], - [[2., 3.], - [2.1, 3.1]], - [[4., 5.], - [4.1, 5.1]], - [[6., 7.], - [6.1, 7.1]], - [[8., 9.], - [8.1, 9.1]], - [[10., 11.], - [10.1, 11.1]]]) * pq.mV + self.waveforms = np.array([[[0., 1.], [0.1, 1.1]], + [[2., 3.], [2.1, 3.1]], + [[4., 5.], [4.1, 5.1]], + [[6., 7.], [6.1, 7.1]], + [[8., 9.], [8.1, 9.1]], + [[10., 11.], [10.1, 11.1]]]) * pq.mV + self.waveforms = np.moveaxis(self.waveforms, 2, 0) self.data = np.array([0.1, 0.5, 1.2, 3.3, 6.4, 7]) self.dataquant = self.data * pq.ms self.train = SpikeTrain(self.dataquant, t_stop=10.0 * pq.ms, @@ -1540,6 +1548,7 @@ def setUp(self): [2.1, 3.1]], [[4., 5.], [4.1, 5.1]]]) * pq.mV + self.waveforms1 = np.moveaxis(self.waveforms1, 2, 0) self.t_start1 = 0.5 self.t_stop1 = 10.0 self.t_start1quant = self.t_start1 * pq.ms diff --git a/neo/test/generate_datasets.py b/neo/test/generate_datasets.py index 6f9dc6707..207a9b0e2 100644 --- a/neo/test/generate_datasets.py +++ b/neo/test/generate_datasets.py @@ -23,7 +23,6 @@ from neo.core.baseneo import _container_name - TEST_ANNOTATIONS = [1, 0, 1.5, "this is a test", datetime.fromtimestamp(424242424), None] @@ -41,7 +40,7 @@ def generate_one_simple_block(block_name='block_0', nb_segment=3, supported_objects=objects, **kws) bl.segments.append(seg) - #if RecordingChannel in objects: + # if RecordingChannel in objects: # populate_RecordingChannel(bl) bl.create_many_to_one_relationship() @@ -51,12 +50,12 @@ def generate_one_simple_block(block_name='block_0', nb_segment=3, def generate_one_simple_segment(seg_name='segment 0', supported_objects=[], nb_analogsignal=4, - t_start=0.*pq.s, - sampling_rate=10*pq.kHz, - duration=6.*pq.s, + t_start=0. * pq.s, + sampling_rate=10 * pq.kHz, + duration=6. 
* pq.s, nb_spiketrain=6, - spikerate_range=[.5*pq.Hz, 12*pq.Hz], + spikerate_range=[.5 * pq.Hz, 12 * pq.Hz], event_types={'stim': ['a', 'b', 'c', 'd'], @@ -83,34 +82,34 @@ def generate_one_simple_segment(seg_name='segment 0', if AnalogSignal in supported_objects: for a in range(nb_analogsignal): anasig = AnalogSignal(rand(int(sampling_rate * duration)), - sampling_rate=sampling_rate, t_start=t_start, - units=pq.mV, channel_index=a, - name='sig %d for segment %s' % (a, seg.name)) + sampling_rate=sampling_rate, t_start=t_start, + units=pq.mV, channel_index=a, + name='sig %d for segment %s' % (a, seg.name)) seg.analogsignals.append(anasig) if SpikeTrain in supported_objects: for s in range(nb_spiketrain): - spikerate = rand()*np.diff(spikerate_range) + spikerate = rand() * np.diff(spikerate_range) spikerate += spikerate_range[0].magnitude - #spikedata = rand(int((spikerate*duration).simplified))*duration - #sptr = SpikeTrain(spikedata, + # spikedata = rand(int((spikerate*duration).simplified))*duration + # sptr = SpikeTrain(spikedata, # t_start=t_start, t_stop=t_start+duration) # #, name = 'spiketrain %d'%s) - spikes = rand(int((spikerate*duration).simplified)) + spikes = rand(int((spikerate * duration).simplified)) spikes.sort() # spikes are supposed to be an ascending sequence - sptr = SpikeTrain(spikes*duration, - t_start=t_start, t_stop=t_start+duration) + sptr = SpikeTrain(spikes * duration, + t_start=t_start, t_stop=t_start + duration) sptr.annotations['channel_index'] = s seg.spiketrains.append(sptr) if Event in supported_objects: for name, labels in event_types.items(): - evt_size = rand()*np.diff(event_size_range) + evt_size = rand() * np.diff(event_size_range) evt_size += event_size_range[0] evt_size = int(evt_size) labels = np.array(labels, dtype='S') - labels = labels[(rand(evt_size)*len(labels)).astype('i')] - evt = Event(times=rand(evt_size)*duration, labels=labels) + labels = labels[(rand(evt_size) * len(labels)).astype('i')] + evt = Event(times=rand(evt_size) * duration, labels=labels) seg.events.append(evt) if Epoch in supported_objects: @@ -120,12 +119,12 @@ def generate_one_simple_segment(seg_name='segment 0', durations = [] while t < duration: times.append(t) - dur = rand()*np.diff(epoch_duration_range) + dur = rand() * np.diff(epoch_duration_range) dur += epoch_duration_range[0] durations.append(dur) - t = t+dur + t = t + dur labels = np.array(labels, dtype='S') - labels = labels[(rand(len(times))*len(labels)).astype('i')] + labels = labels[(rand(len(times)) * len(labels)).astype('i')] epc = Epoch(times=pq.Quantity(times, units=pq.s), durations=pq.Quantity([x[0] for x in durations], units=pq.s), @@ -140,7 +139,7 @@ def generate_one_simple_segment(seg_name='segment 0', def generate_from_supported_objects(supported_objects): - #~ create_many_to_one_relationship + # ~ create_many_to_one_relationship if not supported_objects: raise ValueError('No objects specified') objects = supported_objects @@ -149,12 +148,12 @@ def generate_from_supported_objects(supported_objects): # Chris we do not create RC and RCG if it is not in objects # there is a test in generate_one_simple_block so I removed - #finalize_block(higher) + # finalize_block(higher) elif Segment in objects: higher = generate_one_simple_segment(supported_objects=objects) else: - #TODO + # TODO return None higher.create_many_to_one_relationship() @@ -194,8 +193,8 @@ def get_fake_value(name, datatype, dim=0, dtype='float', seed=None, return np.random.randint(100) if datatype == float: return 1000. 
* np.random.random() - if datatype == datetime: - return datetime.fromtimestamp(1000000000*np.random.random()) + if datatype == datetime: + return datetime.fromtimestamp(1000000000 * np.random.random()) if (name in ['t_start', 't_stop', 'sampling_rate'] and (datatype != pq.Quantity or dim)): @@ -219,27 +218,27 @@ def get_fake_value(name, datatype, dim=0, dtype='float', seed=None, if name == 'sampling_rate': data = np.array(10000.0) elif name == 't_start': - data = np.array(0.0) + data = np.array(0.0) elif name == 't_stop': - data = np.array(1.0) + data = np.array(1.0) elif n and name == 'channel_indexes': - data = np.arange(n) + data = np.arange(n) elif n and name == 'channel_names': - data = np.array(["ch%d" % i for i in range(n)]) + data = np.array(["ch%d" % i for i in range(n)]) elif n and obj == 'AnalogSignal': if name == 'signal': size = [] for _ in range(int(dim)): size.append(np.random.randint(5) + 1) size[1] = n - data = np.random.random(size)*1000. + data = np.random.random(size) * 1000. else: size = [] for _ in range(int(dim)): - if shape is None : + if shape is None: if name == "times": size.append(5) - else : + else: size.append(np.random.randint(5) + 1) else: size.append(shape) @@ -249,7 +248,7 @@ def get_fake_value(name, datatype, dim=0, dtype='float', seed=None, data *= 1000. if np.dtype(dtype) != np.float64: data = data.astype(dtype) - + if datatype == np.ndarray: return data if datatype == list: @@ -273,7 +272,8 @@ def get_fake_values(cls, annotate=True, seed=None, n=None): If annotate is True (default), also add annotations to the values. """ - if hasattr(cls, 'lower'): # is this a test that cls is a string? better to use isinstance(cls, basestring), no? + # is this a test that cls is a string? better to use isinstance(cls, basestring), no? 
+    if hasattr(cls, 'lower'):
         cls = class_by_name[cls]

     kwargs = {}  # assign attributes
@@ -283,22 +283,23 @@ def get_fake_values(cls, annotate=True, seed=None, n=None):
         else:
             iseed = None
         kwargs[attr[0]] = get_fake_value(*attr, seed=iseed, obj=cls, n=n)
-
-    if 'waveforms' in kwargs : #everything here is to force the kwargs to have len(time) == kwargs["waveforms"].shape[0]
-        if len(kwargs["times"]) != kwargs["waveforms"].shape[0] :
-            if len(kwargs["times"]) < kwargs["waveforms"].shape[0] :
-
-                dif = kwargs["waveforms"].shape[0] - len(kwargs["times"])
-
-                new_times =[]
-                for i in kwargs["times"].magnitude :
+
+    # everything here is to force the kwargs to have len(time) == kwargs["waveforms"].shape[1]
+    if 'waveforms' in kwargs:
+        if len(kwargs["times"]) != kwargs["waveforms"].shape[1]:
+            if len(kwargs["times"]) < kwargs["waveforms"].shape[1]:
+
+                dif = kwargs["waveforms"].shape[1] - len(kwargs["times"])
+
+                new_times = []
+                for i in kwargs["times"].magnitude:
                     new_times.append(i)
                 np.random.seed(0)
                 new_times = np.concatenate([new_times, np.random.random(dif)])
                 kwargs["times"] = pq.Quantity(new_times, units=pq.ms)
-            else :
-                kwargs['times'] = kwargs['times'][:kwargs["waveforms"].shape[0]]
+            else:
+                kwargs['times'] = kwargs['times'][:kwargs["waveforms"].shape[1]]

     if 'times' in kwargs and 'signal' in kwargs:
         kwargs['times'] = kwargs['times'][:len(kwargs['signal'])]
@@ -351,7 +352,7 @@ def fake_neo(obj_type="Block", cascade=True, seed=None, n=1):
             # we create a few of each class
             for j in range(n):
                 if seed is not None:
-                    iseed = 10*seed+100*i+1000*j
+                    iseed = 10 * seed + 100 * i + 1000 * j
                 else:
                     iseed = None
                 child = fake_neo(obj_type=childname, cascade=cascade,
@@ -362,7 +363,7 @@ def fake_neo(obj_type="Block", cascade=True, seed=None, n=1):
                 # parent, don't create the object, we will import it from secondary
                 # containers later
                 if (cascade == 'block' and len(child._parent_objects) > 0 and
-                        obj_type != child._parent_objects[-1]):
+                        obj_type != child._parent_objects[-1]):
                     continue
                 getattr(obj, _container_name(childname)).append(child)
@@ -377,7 +378,7 @@ def fake_neo(obj_type="Block", cascade=True, seed=None, n=1):
         for j, unit in enumerate(chx.units):
             for k, train in enumerate(unit.spiketrains):
                 obj.segments[k].spiketrains.append(train)
-    #elif obj_type == 'ChannelIndex':
+    # elif obj_type == 'ChannelIndex':
     #    inds = []
     #    names = []
     #    chinds = np.array([unit.channel_indexes[0] for unit in obj.units])
diff --git a/neo/test/iotest/test_blackrockio.py b/neo/test/iotest/test_blackrockio.py
index 4bd4cded0..a62d91ffc 100644
--- a/neo/test/iotest/test_blackrockio.py
+++ b/neo/test/iotest/test_blackrockio.py
@@ -229,11 +229,11 @@ def test_compare_blackrockio_with_matlabloader_v21(self):
                 # Compare waveforms
                 matlab_wf = wf_ml[np.nonzero(
-                    np.logical_and(elec_ml == channelid, unit_ml == unitid)), :][0]
-                # Atleast_2d as correction for waveforms that are saved
-                # in single dimension in SpikeTrain
-                # because only one waveform is available
-                assert_equal(np.atleast_2d(np.squeeze(st_i.waveforms).magnitude), matlab_wf)
+                    np.logical_and(elec_ml == channelid, unit_ml == unitid)), :][0].T
+                wfs = np.squeeze(st_i.waveforms).magnitude
+                if len(wfs.shape) == 1:
+                    wfs = wfs[:, None]  # expand to two dimensions in the case of a single waveform
+                assert_equal(wfs, matlab_wf)

                 # Compare spike timestamps
                 matlab_spikes = ts_ml[np.nonzero(
diff --git a/neo/test/iotest/test_brainwaresrcio.py b/neo/test/iotest/test_brainwaresrcio.py
index 9ef915174..d5a1a728f 100644
--- a/neo/test/iotest/test_brainwaresrcio.py
+++ 
b/neo/test/iotest/test_brainwaresrcio.py @@ -258,6 +258,7 @@ def proc_src_condition_unit_repetition(sweep, damaIndex, timeStamp, sweepLen, t_stop = pq.Quantity(sweepLen, units=pq.ms, dtype=np.float32) trig2 = pq.Quantity(trig2, units=pq.ms, dtype=np.uint8) waveforms = pq.Quantity(shapes, dtype=np.int8, units=pq.mV) + waveforms = np.moveaxis(waveforms, 2, 0) sampling_period = pq.Quantity(ADperiod, units=pq.us) train = SpikeTrain(times=times, t_start=t_start, t_stop=t_stop, @@ -270,6 +271,14 @@ def proc_src_condition_unit_repetition(sweep, damaIndex, timeStamp, sweepLen, return train +def empty_waveform_dimension_correction(block): + for seg in block.segments: + for st in seg.spiketrains: + if st.waveforms is not None: + if 0 in st.waveforms.shape: + st.waveforms = np.moveaxis(st.waveforms, 2, 0) + + class BrainwareSrcIOTestCase(BaseTestIO, unittest.TestCase): ''' Unit test testcase for neo.io.BrainwareSrcIO @@ -325,6 +334,7 @@ def test_against_reference(self): if not refname: continue obj = self.read_file(filename=filename, readall=True)[0] + empty_waveform_dimension_correction(obj) refobj = proc_src(self.get_filename_path(refname)) try: assert_neo_object_is_compliant(obj) diff --git a/neo/test/iotest/test_neuralynxio.py b/neo/test/iotest/test_neuralynxio.py index 588a24db1..28360a90c 100644 --- a/neo/test/iotest/test_neuralynxio.py +++ b/neo/test/iotest/test_neuralynxio.py @@ -5,20 +5,12 @@ # needed for python 3 compatibility from __future__ import absolute_import - -import os -import sys -import re +import time import warnings - import unittest - -import numpy as np import quantities as pq - from neo.test.iotest.common_io_test import BaseTestIO from neo.core import * - from neo.io.neuralynxio import NeuralynxIO from neo.io.neuralynxio import NeuralynxIO as NewNeuralynxIO from neo.io.neuralynxio_v1 import NeuralynxIO as OldNeuralynxIO @@ -88,7 +80,7 @@ def test_read_block(self): block = nio.read_block(load_waveforms=True) self.assertEqual(len(block.segments[0].analogsignals), 1) self.assertEqual(len(block.segments[0].spiketrains), 2) - self.assertEqual(block.segments[0].spiketrains[0].waveforms.shape[0], + self.assertEqual(block.segments[0].spiketrains[0].waveforms.shape[1], block.segments[0].spiketrains[0].shape[0]) self.assertGreater(len(block.segments[0].events), 0) @@ -157,7 +149,6 @@ def test_read_block(self): self.assertEqual(len(block.channel_indexes), 1) def test_read_segment(self): - dirname = self.get_filename_path('Cheetah_v5.7.4/original_data') nio = NeuralynxIO(dirname=dirname, use_cache=False) @@ -205,15 +196,10 @@ def test_gap_handling(self): # All above must delete before merging to master # the purpose is to test Old and New NeuralynxIO - -import time - - def compare_old_and_new_neuralynxio(): - base = '/tmp/files_for_testing_neo/neuralynx/' dirname = base + 'Cheetah_v5.5.1/original_data/' - #~ dirname = base+'Cheetah_v5.7.4/original_data/' + # ~ dirname = base+'Cheetah_v5.7.4/original_data/' t0 = time.perf_counter() newreader = NewNeuralynxIO(dirname) @@ -281,7 +267,7 @@ def compare_annotations(anno1, anno2): warnings.warn('Different numbers of annotations! 
{} != {' '}\nSkipping further comparison of this ' 'annotation list.'.format( - anno1.keys(), anno2.keys())) + anno1.keys(), anno2.keys())) return assert anno1.keys() == anno2.keys() for key in anno1.keys(): @@ -297,8 +283,7 @@ def compare_attributes(child1, child2): continue if type(child1) == SpikeTrain and attr_name == 'times': continue - unequal = child1.__getattribute__(attr_name) != \ - child2.__getattribute__(attr_name) + unequal = child1.__getattribute__(attr_name) != child2.__getattribute__(attr_name) if hasattr(unequal, 'any'): unequal = unequal.any() if unequal: @@ -311,4 +296,4 @@ def compare_attributes(child1, child2): if __name__ == '__main__': unittest.main() - #~ compare_old_and_new_neuralynxio() + # ~ compare_old_and_new_neuralynxio() diff --git a/neo/test/iotest/test_nixio.py b/neo/test/iotest/test_nixio.py index 34a4c1dbf..cac30ae0d 100644 --- a/neo/test/iotest/test_nixio.py +++ b/neo/test/iotest/test_nixio.py @@ -750,7 +750,7 @@ def test_spiketrain_write(self): seg.spiketrains.append(spiketrain) self.write_and_compare([block]) - waveforms = self.rquant((3, 5, 10), pq.mV) + waveforms = self.rquant((10, 3, 5), pq.mV) spiketrain = SpikeTrain(times=[1, 1.1, 1.2] * pq.ms, t_stop=1.5 * pq.s, name="spikes with wf", description="spikes for waveform test",