6767constructions of a RawIO for a given set of data.
6868
6969"""
70+ from __future__ import annotations
7071
7172import logging
7273import numpy as np
@@ -133,7 +134,7 @@ class BaseRawIO:
133134
134135 rawmode = None # one key from possible_raw_modes
135136
136- def __init__ (self , use_cache = False , cache_path = 'same_as_resource' , ** kargs ):
137+ def __init__ (self , use_cache : bool = False , cache_path = 'same_as_resource' , ** kargs ):
137138 """
138139 :TODO: Why multi-file would have a single filename is confusing here - shouldn't
139140 the name of this argument be filenames_list or filenames_base or similar?
@@ -369,7 +370,7 @@ def block_count(self):
369370 """return number of blocks"""
370371 return self .header ['nb_block' ]
371372
372- def segment_count (self , block_index ):
373+ def segment_count (self , block_index : int ):
373374 """return number of segments for a given block"""
374375 return self .header ['nb_segment' ][block_index ]
375376
@@ -379,7 +380,7 @@ def signal_streams_count(self):
379380 """
380381 return len (self .header ['signal_streams' ])
381382
382- def signal_channels_count (self , stream_index ):
383+ def signal_channels_count (self , stream_index : int ):
383384 """Return the number of signal channels for a given stream.
384385 This number is the same for all Blocks and Segments.
385386 """
@@ -400,7 +401,7 @@ def event_channels_count(self):
400401 """
401402 return len (self .header ['event_channels' ])
402403
403- def segment_t_start (self , block_index , seg_index ):
404+ def segment_t_start (self , block_index : int , seg_index : int ):
404405 """Global t_start of a Segment in s. Shared by all objects except
405406 for AnalogSignal.
406407 """
@@ -445,7 +446,7 @@ def _check_stream_signal_channel_characteristics(self):
445446
446447 self ._several_channel_groups = signal_streams .size > 1
447448
448- def channel_name_to_index (self , stream_index , channel_names ):
449+ def channel_name_to_index (self , stream_index : int , channel_names : list [ str ] ):
449450 """
450451 Inside a stream, transform channel_names to channel_indexes.
451452 Based on self.header['signal_channels']
@@ -459,7 +460,7 @@ def channel_name_to_index(self, stream_index, channel_names):
459460 channel_indexes = np .array ([chan_names .index (name ) for name in channel_names ])
460461 return channel_indexes
461462
462- def channel_id_to_index (self , stream_index , channel_ids ):
463+ def channel_id_to_index (self , stream_index : int , channel_ids : list [ str ] ):
463464 """
464465 Inside a stream, transform channel_ids to channel_indexes.
465466 Based on self.header['signal_channels']
@@ -473,7 +474,7 @@ def channel_id_to_index(self, stream_index, channel_ids):
473474 channel_indexes = np .array ([chan_ids .index (chan_id ) for chan_id in channel_ids ])
474475 return channel_indexes
475476
476- def _get_channel_indexes (self , stream_index , channel_indexes , channel_names , channel_ids ):
477+ def _get_channel_indexes (self , stream_index : int , channel_indexes : list [ int ] | None , channel_names : list [ str ] | None , channel_ids : list [ str ] | None ):
477478 """
478479 Select channel_indexes for a stream based on channel_indexes/channel_names/channel_ids
479480 depending which is not None.
@@ -493,7 +494,7 @@ def _get_stream_index_from_arg(self, stream_index_arg):
493494 stream_index = stream_index_arg
494495 return stream_index
495496
496- def get_signal_size (self , block_index , seg_index , stream_index = None ):
497+ def get_signal_size (self , block_index : int , seg_index : int , stream_index : int | None = None ):
497498 """
498499 Retrieve the length of a single section of the channels in a stream.
499500 :param block_index:
@@ -504,7 +505,7 @@ def get_signal_size(self, block_index, seg_index, stream_index=None):
504505 stream_index = self ._get_stream_index_from_arg (stream_index )
505506 return self ._get_signal_size (block_index , seg_index , stream_index )
506507
507- def get_signal_t_start (self , block_index , seg_index , stream_index = None ):
508+ def get_signal_t_start (self , block_index : int , seg_index : int , stream_index : int | None = None ):
508509 """
509510 Retrieve the t_start of a single section of the channels in a stream.
510511 :param block_index:
@@ -515,7 +516,7 @@ def get_signal_t_start(self, block_index, seg_index, stream_index=None):
515516 stream_index = self ._get_stream_index_from_arg (stream_index )
516517 return self ._get_signal_t_start (block_index , seg_index , stream_index )
517518
518- def get_signal_sampling_rate (self , stream_index = None ):
519+ def get_signal_sampling_rate (self , stream_index : int | None = None ):
519520 """
520521 Retrieve sampling rate for a stream and all channels in that stream.
521522 :param stream_index:
@@ -528,9 +529,9 @@ def get_signal_sampling_rate(self, stream_index=None):
528529 sr = signal_channels [0 ]['sampling_rate' ]
529530 return float (sr )
530531
531- def get_analogsignal_chunk (self , block_index = 0 , seg_index = 0 , i_start = None , i_stop = None ,
532- stream_index = None , channel_indexes = None , channel_names = None ,
533- channel_ids = None , prefer_slice = False ):
532+ def get_analogsignal_chunk (self , block_index : int = 0 , seg_index : int = 0 , i_start : int | None = None , i_stop : int | None = None ,
533+ stream_index : int | None = None , channel_indexes : list [ int ] | None = None , channel_names : list [ str ] | None = None ,
534+ channel_ids : list [ str ] | None = None , prefer_slice : bool = False ):
534535 """
535536 Return a chunk of raw signal as a Numpy array. columns correspond to samples from a
536537 section of a single channel of recording. The channels are chosen either by channel_names,
@@ -587,8 +588,8 @@ def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_sto
587588
588589 return raw_chunk
589590
590- def rescale_signal_raw_to_float (self , raw_signal , dtype = 'float32' , stream_index = None ,
591- channel_indexes = None , channel_names = None , channel_ids = None ):
591+ def rescale_signal_raw_to_float (self , raw_signal : np . ndarray , dtype : np . dtype = 'float32' , stream_index : int | None = None ,
592+ channel_indexes : list [ int ] | None = None , channel_names : list [ str ] | None = None , channel_ids : list [ str ] | None = None ):
592593 """
593594 Rescale a chunk of raw signals which are provided as a Numpy array. These are normally
594595 returned by a call to get_analogsignal_chunk. The channels are specified either by
@@ -627,7 +628,7 @@ def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index=
627628 return float_signal
628629
629630 # spiketrain and unit zone
630- def spike_count (self , block_index = 0 , seg_index = 0 , spike_channel_index = 0 ):
631+ def spike_count (self , block_index : int = 0 , seg_index : int = 0 , spike_channel_index : int = 0 ):
631632 return self ._spike_count (block_index , seg_index , spike_channel_index )
632633
633634 def get_spike_timestamps (self , block_index = 0 , seg_index = 0 , spike_channel_index = 0 ,
@@ -643,21 +644,21 @@ def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0
643644 spike_channel_index , t_start , t_stop )
644645 return timestamp
645646
646- def rescale_spike_timestamp (self , spike_timestamps , dtype = 'float64' ):
647+ def rescale_spike_timestamp (self , spike_timestamps : np . ndarray , dtype : np . dtype = 'float64' ):
647648 """
648649 Rescale spike timestamps to seconds.
649650 """
650651 return self ._rescale_spike_timestamp (spike_timestamps , dtype )
651652
652653 # spiketrain waveform zone
653- def get_spike_raw_waveforms (self , block_index = 0 , seg_index = 0 , spike_channel_index = 0 ,
654- t_start = None , t_stop = None ):
654+ def get_spike_raw_waveforms (self , block_index : int = 0 , seg_index : int = 0 , spike_channel_index : int = 0 ,
655+ t_start : float | None = None , t_stop : float | None = None ):
655656 wf = self ._get_spike_raw_waveforms (block_index , seg_index ,
656657 spike_channel_index , t_start , t_stop )
657658 return wf
658659
659- def rescale_waveforms_to_float (self , raw_waveforms , dtype = 'float32' ,
660- spike_channel_index = 0 ):
660+ def rescale_waveforms_to_float (self , raw_waveforms : np . ndarray , dtype : np . dtype = 'float32' ,
661+ spike_channel_index : int = 0 ):
661662 wf_gain = self .header ['spike_channels' ]['wf_gain' ][spike_channel_index ]
662663 wf_offset = self .header ['spike_channels' ]['wf_offset' ][spike_channel_index ]
663664
@@ -671,11 +672,11 @@ def rescale_waveforms_to_float(self, raw_waveforms, dtype='float32',
671672 return float_waveforms
672673
673674 # event and epoch zone
674- def event_count (self , block_index = 0 , seg_index = 0 , event_channel_index = 0 ):
675+ def event_count (self , block_index : int = 0 , seg_index : int = 0 , event_channel_index : int = 0 ):
675676 return self ._event_count (block_index , seg_index , event_channel_index )
676677
677- def get_event_timestamps (self , block_index = 0 , seg_index = 0 , event_channel_index = 0 ,
678- t_start = None , t_stop = None ):
678+ def get_event_timestamps (self , block_index : int = 0 , seg_index : int = 0 , event_channel_index : int = 0 ,
679+ t_start : float | None = None , t_stop : float | None = None ):
679680 """
680681 The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
681682 Sometimes it is the index on the signal but not always.
@@ -693,21 +694,21 @@ def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0
693694 block_index , seg_index , event_channel_index , t_start , t_stop )
694695 return timestamp , durations , labels
695696
696- def rescale_event_timestamp (self , event_timestamps , dtype = 'float64' ,
697- event_channel_index = 0 ):
697+ def rescale_event_timestamp (self , event_timestamps : np . ndarray , dtype : np . dtype = 'float64' ,
698+ event_channel_index : int = 0 ):
698699 """
699700 Rescale event timestamps to seconds.
700701 """
701702 return self ._rescale_event_timestamp (event_timestamps , dtype , event_channel_index )
702703
703- def rescale_epoch_duration (self , raw_duration , dtype = 'float64' ,
704- event_channel_index = 0 ):
704+ def rescale_epoch_duration (self , raw_duration : np . ndarray , dtype : np . dtype = 'float64' ,
705+ event_channel_index : int = 0 ):
705706 """
706707 Rescale epoch raw duration to seconds.
707708 """
708709 return self ._rescale_epoch_duration (raw_duration , dtype , event_channel_index )
709710
710- def setup_cache (self , cache_path , ** init_kargs ):
711+ def setup_cache (self , cache_path : 'home' | 'same_as_resource' , ** init_kargs ):
711712 try :
712713 import joblib
713714 except ImportError :
@@ -735,7 +736,7 @@ def setup_cache(self, cache_path, **init_kargs):
735736 dirname = os .path .dirname (resource_name )
736737 else :
737738 assert os .path .exists (cache_path ), \
738- 'cache_path do not exists use "home" or "same_as_resource" to make this auto'
739+ 'cache_path does not exists use "home" or "same_as_resource" to make this auto'
739740
740741 # the hash of the resource (dir of file) is done with filename+datetime
741742 # TODO make something more sophisticated when rawmode='one-dir' that use all
@@ -776,32 +777,32 @@ def _parse_header(self):
776777 def _source_name (self ):
777778 raise (NotImplementedError )
778779
779- def _segment_t_start (self , block_index , seg_index ):
780+ def _segment_t_start (self , block_index : int , seg_index : int ):
780781 raise (NotImplementedError )
781782
782- def _segment_t_stop (self , block_index , seg_index ):
783+ def _segment_t_stop (self , block_index : int , seg_index : int ):
783784 raise (NotImplementedError )
784785
785786 ###
786787 # signal and channel zone
787- def _get_signal_size (self , block_index , seg_index , stream_index ):
788+ def _get_signal_size (self , block_index : int , seg_index : int , stream_index : int ):
788789 """
789790 Return the size of a set of AnalogSignals indexed by channel_indexes.
790791
791792 All channels indexed must have the same size and t_start.
792793 """
793794 raise (NotImplementedError )
794795
795- def _get_signal_t_start (self , block_index , seg_index , stream_index ):
796+ def _get_signal_t_start (self , block_index : int , seg_index : int , stream_index : int ):
796797 """
797798 Return the t_start of a set of AnalogSignals indexed by channel_indexes.
798799
799800 All channels indexed must have the same size and t_start.
800801 """
801802 raise (NotImplementedError )
802803
803- def _get_analogsignal_chunk (self , block_index , seg_index , i_start , i_stop ,
804- stream_index , channel_indexes ):
804+ def _get_analogsignal_chunk (self , block_index : int , seg_index : int , i_start : int | None , i_stop : int | None ,
805+ stream_index : int , channel_indexes : list [ int ] | None ):
805806 """
806807 Return the samples from a set of AnalogSignals indexed
807808 by stream_index and channel_indexes (local index inner stream).
@@ -815,38 +816,38 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
815816
816817 ###
817818 # spiketrain and unit zone
818- def _spike_count (self , block_index , seg_index , spike_channel_index ):
819+ def _spike_count (self , block_index : int , seg_index : int , spike_channel_index : int ):
819820 raise (NotImplementedError )
820821
821- def _get_spike_timestamps (self , block_index , seg_index ,
822- spike_channel_index , t_start , t_stop ):
822+ def _get_spike_timestamps (self , block_index : int , seg_index : int ,
823+ spike_channel_index : int , t_start : float | None , t_stop : float | None ):
823824 raise (NotImplementedError )
824825
825- def _rescale_spike_timestamp (self , spike_timestamps , dtype ):
826+ def _rescale_spike_timestamp (self , spike_timestamps : np . ndarray , dtype : np . dtype ):
826827 raise (NotImplementedError )
827828
828829 ###
829830 # spike waveforms zone
830- def _get_spike_raw_waveforms (self , block_index , seg_index ,
831- spike_channel_index , t_start , t_stop ):
831+ def _get_spike_raw_waveforms (self , block_index : int , seg_index : int ,
832+ spike_channel_index : int , t_start : float | None , t_stop : float | None ):
832833 raise (NotImplementedError )
833834
834835 ###
835836 # event and epoch zone
836- def _event_count (self , block_index , seg_index , event_channel_index ):
837+ def _event_count (self , block_index : int , seg_index : int , event_channel_index : int ):
837838 raise (NotImplementedError )
838839
839- def _get_event_timestamps (self , block_index , seg_index , event_channel_index , t_start , t_stop ):
840+ def _get_event_timestamps (self , block_index : int , seg_index : int , event_channel_index : int , t_start : float | None , t_stop : float | None ):
840841 raise (NotImplementedError )
841842
842- def _rescale_event_timestamp (self , event_timestamps , dtype ):
843+ def _rescale_event_timestamp (self , event_timestamps : np . ndarray , dtype : np . dtype ):
843844 raise (NotImplementedError )
844845
845- def _rescale_epoch_duration (self , raw_duration , dtype ):
846+ def _rescale_epoch_duration (self , raw_duration : np . ndarray , dtype : np . dtype ):
846847 raise (NotImplementedError )
847848
848849
849- def pprint_vector (vector , lim = 8 ):
850+ def pprint_vector (vector , lim : int = 8 ):
850851 vector = np .asarray (vector )
851852 assert vector .ndim == 1
852853 if len (vector ) > lim :
0 commit comments