@@ -134,7 +134,7 @@ class BaseRawIO:
134134
135135 rawmode = None # one key from possible_raw_modes
136136
137- def __init__ (self , use_cache : bool = False , cache_path = 'same_as_resource' , ** kargs ):
137+ def __init__ (self , use_cache : bool = False , cache_path = 'same_as_resource' , ** kargs ):
138138 """
139139 :TODO: Why multi-file would have a single filename is confusing here - shouldn't
140140 the name of this argument be filenames_list or filenames_base or similar?
@@ -474,7 +474,7 @@ def channel_id_to_index(self, stream_index: int, channel_ids: list[str]):
474474 channel_indexes = np .array ([chan_ids .index (chan_id ) for chan_id in channel_ids ])
475475 return channel_indexes
476476
477- def _get_channel_indexes (self , stream_index :int , channel_indexes : list [int ]| None , channel_names : list [str ]| None , channel_ids : list [str ]| None ):
477+ def _get_channel_indexes (self , stream_index : int , channel_indexes : list [int ] | None , channel_names : list [str ] | None , channel_ids : list [str ] | None ):
478478 """
479479 Select channel_indexes for a stream based on channel_indexes/channel_names/channel_ids
480480 depending which is not None.
@@ -485,7 +485,7 @@ def _get_channel_indexes(self, stream_index:int, channel_indexes: list[int]|None
485485 channel_indexes = self .channel_id_to_index (stream_index , channel_ids )
486486 return channel_indexes
487487
488- def _get_stream_index_from_arg (self , stream_index_arg ):
488+ def _get_stream_index_from_arg (self , stream_index_arg : int | None ):
489489 if stream_index_arg is None :
490490 assert self .header ['signal_streams' ].size == 1
491491 stream_index = 0
@@ -494,7 +494,7 @@ def _get_stream_index_from_arg(self, stream_index_arg):
494494 stream_index = stream_index_arg
495495 return stream_index
496496
497- def get_signal_size (self , block_index : int , seg_index : int , stream_index : int | None = None ):
497+ def get_signal_size (self , block_index : int , seg_index : int , stream_index : int | None = None ):
498498 """
499499 Retrieve the length of a single section of the channels in a stream.
500500 :param block_index:
@@ -505,7 +505,7 @@ def get_signal_size(self, block_index: int, seg_index: int, stream_index: int|No
505505 stream_index = self ._get_stream_index_from_arg (stream_index )
506506 return self ._get_signal_size (block_index , seg_index , stream_index )
507507
508- def get_signal_t_start (self , block_index : int , seg_index : int , stream_index : int | None = None ):
508+ def get_signal_t_start (self , block_index : int , seg_index : int , stream_index : int | None = None ):
509509 """
510510 Retrieve the t_start of a single section of the channels in a stream.
511511 :param block_index:
@@ -516,7 +516,7 @@ def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int
516516 stream_index = self ._get_stream_index_from_arg (stream_index )
517517 return self ._get_signal_t_start (block_index , seg_index , stream_index )
518518
519- def get_signal_sampling_rate (self , stream_index : int | None = None ):
519+ def get_signal_sampling_rate (self , stream_index : int | None = None ):
520520 """
521521 Retrieve sampling rate for a stream and all channels in that stream.
522522 :param stream_index:
@@ -529,9 +529,9 @@ def get_signal_sampling_rate(self, stream_index: int| None=None):
529529 sr = signal_channels [0 ]['sampling_rate' ]
530530 return float (sr )
531531
532- def get_analogsignal_chunk (self , block_index : int = 0 , seg_index : int = 0 , i_start : int | None = None , i_stop : int | None = None ,
533- stream_index : int | None = None , channel_indexes : list [int ]| None = None , channel_names : list [str ]| None = None ,
534- channel_ids : list [str ]| None = None , prefer_slice :bool = False ):
532+ def get_analogsignal_chunk (self , block_index : int = 0 , seg_index : int = 0 , i_start : int | None = None , i_stop : int | None = None ,
533+ stream_index : int | None = None , channel_indexes : list [int ] | None = None , channel_names : list [str ] | None = None ,
534+ channel_ids : list [str ] | None = None , prefer_slice : bool = False ):
535535 """
536536 Return a chunk of raw signal as a Numpy array. columns correspond to samples from a
537537 section of a single channel of recording. The channels are chosen either by channel_names,
@@ -588,8 +588,8 @@ def get_analogsignal_chunk(self, block_index: int =0, seg_index: int =0, i_start
588588
589589 return raw_chunk
590590
591- def rescale_signal_raw_to_float (self , raw_signal : np .ndarray , dtype : np .dtype = 'float32' , stream_index : int | None = None ,
592- channel_indexes : list [int ]| None = None , channel_names : list [str ]| None = None , channel_ids : list [str ]| None = None ):
591+ def rescale_signal_raw_to_float (self , raw_signal : np .ndarray , dtype : np .dtype = 'float32' , stream_index : int | None = None ,
592+ channel_indexes : list [int ] | None = None , channel_names : list [str ] | None = None , channel_ids : list [str ] | None = None ):
593593 """
594594 Rescale a chunk of raw signals which are provided as a Numpy array. These are normally
595595 returned by a call to get_analogsignal_chunk. The channels are specified either by
@@ -628,11 +628,11 @@ def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype='f
628628 return float_signal
629629
630630 # spiketrain and unit zone
631- def spike_count (self , block_index : int = 0 , seg_index : int = 0 , spike_channel_index :int = 0 ):
631+ def spike_count (self , block_index : int = 0 , seg_index : int = 0 , spike_channel_index : int = 0 ):
632632 return self ._spike_count (block_index , seg_index , spike_channel_index )
633633
634- def get_spike_timestamps (self , block_index = 0 , seg_index = 0 , spike_channel_index = 0 ,
635- t_start = None , t_stop = None ):
634+ def get_spike_timestamps (self , block_index : int = 0 , seg_index : int = 0 , spike_channel_index : int = 0 ,
635+ t_start : float | None = None , t_stop : float | None = None ):
636636 """
637637 The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
638638 Sometimes it is the index on the signal but not always.
@@ -644,21 +644,21 @@ def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0
644644 spike_channel_index , t_start , t_stop )
645645 return timestamp
646646
647- def rescale_spike_timestamp (self , spike_timestamps : np .ndarray , dtype : np .dtype = 'float64' ):
647+ def rescale_spike_timestamp (self , spike_timestamps : np .ndarray , dtype : np .dtype = 'float64' ):
648648 """
649649 Rescale spike timestamps to seconds.
650650 """
651651 return self ._rescale_spike_timestamp (spike_timestamps , dtype )
652652
653653 # spiketrain waveform zone
654- def get_spike_raw_waveforms (self , block_index : int = 0 , seg_index : int = 0 , spike_channel_index : int = 0 ,
655- t_start : float | None = None , t_stop : float | None = None ):
654+ def get_spike_raw_waveforms (self , block_index : int = 0 , seg_index : int = 0 , spike_channel_index : int = 0 ,
655+ t_start : float | None = None , t_stop : float | None = None ):
656656 wf = self ._get_spike_raw_waveforms (block_index , seg_index ,
657657 spike_channel_index , t_start , t_stop )
658658 return wf
659659
660- def rescale_waveforms_to_float (self , raw_waveforms : np .ndarray , dtype : np .dtype = 'float32' ,
661- spike_channel_index : int = 0 ):
660+ def rescale_waveforms_to_float (self , raw_waveforms : np .ndarray , dtype : np .dtype = 'float32' ,
661+ spike_channel_index : int = 0 ):
662662 wf_gain = self .header ['spike_channels' ]['wf_gain' ][spike_channel_index ]
663663 wf_offset = self .header ['spike_channels' ]['wf_offset' ][spike_channel_index ]
664664
@@ -672,11 +672,11 @@ def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype
672672 return float_waveforms
673673
674674 # event and epoch zone
675- def event_count (self , block_index :int = 0 , seg_index : int = 0 , event_channel_index : int = 0 ):
675+ def event_count (self , block_index : int = 0 , seg_index : int = 0 , event_channel_index : int = 0 ):
676676 return self ._event_count (block_index , seg_index , event_channel_index )
677677
678- def get_event_timestamps (self , block_index : int = 0 , seg_index : int = 0 , event_channel_index : int = 0 ,
679- t_start : float | None = None , t_stop : float | None = None ):
678+ def get_event_timestamps (self , block_index : int = 0 , seg_index : int = 0 , event_channel_index : int = 0 ,
679+ t_start : float | None = None , t_stop : float | None = None ):
680680 """
681681 The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
682682 Sometimes it is the index on the signal but not always.
@@ -694,21 +694,21 @@ def get_event_timestamps(self, block_index: int =0, seg_index: int =0, event_cha
694694 block_index , seg_index , event_channel_index , t_start , t_stop )
695695 return timestamp , durations , labels
696696
697- def rescale_event_timestamp (self , event_timestamps : np .ndarray , dtype : np .dtype = 'float64' ,
698- event_channel_index :int = 0 ):
697+ def rescale_event_timestamp (self , event_timestamps : np .ndarray , dtype : np .dtype = 'float64' ,
698+ event_channel_index :int = 0 ):
699699 """
700700 Rescale event timestamps to seconds.
701701 """
702702 return self ._rescale_event_timestamp (event_timestamps , dtype , event_channel_index )
703703
704- def rescale_epoch_duration (self , raw_duration : np .ndarray , dtype : np .dtype = 'float64' ,
705- event_channel_index :int = 0 ):
704+ def rescale_epoch_duration (self , raw_duration : np .ndarray , dtype : np .dtype = 'float64' ,
705+ event_channel_index : int = 0 ):
706706 """
707707 Rescale epoch raw duration to seconds.
708708 """
709709 return self ._rescale_epoch_duration (raw_duration , dtype , event_channel_index )
710710
711- def setup_cache (self , cache_path : 'home' | 'same_as_resource' , ** init_kargs ):
711+ def setup_cache (self , cache_path : 'home' | 'same_as_resource' , ** init_kargs ):
712712 try :
713713 import joblib
714714 except ImportError :
@@ -780,7 +780,7 @@ def _source_name(self):
780780 def _segment_t_start (self , block_index : int , seg_index : int ):
781781 raise (NotImplementedError )
782782
783- def _segment_t_stop (self , block_index :int , seg_index : int ):
783+ def _segment_t_stop (self , block_index : int , seg_index : int ):
784784 raise (NotImplementedError )
785785
786786 ###
@@ -801,8 +801,8 @@ def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: in
801801 """
802802 raise (NotImplementedError )
803803
804- def _get_analogsignal_chunk (self , block_index : int , seg_index : int , i_start : int | None , i_stop : int | None ,
805- stream_index : int , channel_indexes : list [int ]| None ):
804+ def _get_analogsignal_chunk (self , block_index : int , seg_index : int , i_start : int | None , i_stop : int | None ,
805+ stream_index : int , channel_indexes : list [int ] | None ):
806806 """
807807 Return the samples from a set of AnalogSignals indexed
808808 by stream_index and channel_indexes (local index inner stream).
@@ -820,7 +820,7 @@ def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: in
820820 raise (NotImplementedError )
821821
822822 def _get_spike_timestamps (self , block_index : int , seg_index : int ,
823- spike_channel_index : int , t_start : float | None , t_stop : float | None ):
823+ spike_channel_index : int , t_start : float | None , t_stop : float | None ):
824824 raise (NotImplementedError )
825825
826826 def _rescale_spike_timestamp (self , spike_timestamps : np .ndarray , dtype : np .dtype ):
@@ -829,15 +829,15 @@ def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype
829829 ###
830830 # spike waveforms zone
831831 def _get_spike_raw_waveforms (self , block_index : int , seg_index : int ,
832- spike_channel_index : int , t_start : float | None , t_stop : float | None ):
832+ spike_channel_index : int , t_start : float | None , t_stop : float | None ):
833833 raise (NotImplementedError )
834834
835835 ###
836836 # event and epoch zone
837837 def _event_count (self , block_index : int , seg_index : int , event_channel_index : int ):
838838 raise (NotImplementedError )
839839
840- def _get_event_timestamps (self , block_index : int , seg_index : int , event_channel_index : int , t_start : float | None , t_stop : float | None ):
840+ def _get_event_timestamps (self , block_index : int , seg_index : int , event_channel_index : int , t_start : float | None , t_stop : float | None ):
841841 raise (NotImplementedError )
842842
843843 def _rescale_event_timestamp (self , event_timestamps : np .ndarray , dtype : np .dtype ):
@@ -847,7 +847,7 @@ def _rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype):
847847 raise (NotImplementedError )
848848
849849
850- def pprint_vector (vector , lim : int = 8 ):
850+ def pprint_vector (vector , lim : int = 8 ):
851851 vector = np .asarray (vector )
852852 assert vector .ndim == 1
853853 if len (vector ) > lim :
0 commit comments