Skip to content

Commit 48ddbf4

Browse files
committed
initial typing in base
1 parent 25d0f1d commit 48ddbf4

File tree

3 files changed

+72
-66
lines changed

3 files changed

+72
-66
lines changed

neo/io/baseio.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
1111
If you want a model for developing a new IO start from exampleIO.
1212
"""
13+
from __future__ import annotations
14+
from pathlib import Path
1315

1416
try:
1517
from collections.abc import Sequence
@@ -96,7 +98,7 @@ class BaseIO:
9698

9799
mode = 'file' # or 'fake' or 'dir' or 'database'
98100

99-
def __init__(self, filename=None, **kargs):
101+
def __init__(self, filename: str | Path =None, **kargs):
100102
self.filename = str(filename)
101103
# create a logger for the IO class
102104
fullname = self.__class__.__module__ + '.' + self.__class__.__name__
@@ -111,7 +113,7 @@ def __init__(self, filename=None, **kargs):
111113
corelogger.addHandler(logging_handler)
112114

113115
######## General read/write methods #######################
114-
def read(self, lazy=False, **kargs):
116+
def read(self, lazy:bool=False, **kargs):
115117
"""
116118
Return all data from the file as a list of Blocks
117119
"""

neo/rawio/baserawio.py

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
constructions of a RawIO for a given set of data.
6868
6969
"""
70+
from __future__ import annotations
7071

7172
import logging
7273
import numpy as np
@@ -133,7 +134,7 @@ class BaseRawIO:
133134

134135
rawmode = None # one key from possible_raw_modes
135136

136-
def __init__(self, use_cache=False, cache_path='same_as_resource', **kargs):
137+
def __init__(self, use_cache: bool=False, cache_path='same_as_resource', **kargs):
137138
"""
138139
:TODO: Why multi-file would have a single filename is confusing here - shouldn't
139140
the name of this argument be filenames_list or filenames_base or similar?
@@ -369,7 +370,7 @@ def block_count(self):
369370
"""return number of blocks"""
370371
return self.header['nb_block']
371372

372-
def segment_count(self, block_index):
373+
def segment_count(self, block_index: int):
373374
"""return number of segments for a given block"""
374375
return self.header['nb_segment'][block_index]
375376

@@ -379,7 +380,7 @@ def signal_streams_count(self):
379380
"""
380381
return len(self.header['signal_streams'])
381382

382-
def signal_channels_count(self, stream_index):
383+
def signal_channels_count(self, stream_index: int):
383384
"""Return the number of signal channels for a given stream.
384385
This number is the same for all Blocks and Segments.
385386
"""
@@ -400,7 +401,7 @@ def event_channels_count(self):
400401
"""
401402
return len(self.header['event_channels'])
402403

403-
def segment_t_start(self, block_index, seg_index):
404+
def segment_t_start(self, block_index: int, seg_index: int):
404405
"""Global t_start of a Segment in s. Shared by all objects except
405406
for AnalogSignal.
406407
"""
@@ -445,7 +446,7 @@ def _check_stream_signal_channel_characteristics(self):
445446

446447
self._several_channel_groups = signal_streams.size > 1
447448

448-
def channel_name_to_index(self, stream_index, channel_names):
449+
def channel_name_to_index(self, stream_index: int, channel_names: list[str]):
449450
"""
450451
Inside a stream, transform channel_names to channel_indexes.
451452
Based on self.header['signal_channels']
@@ -459,7 +460,7 @@ def channel_name_to_index(self, stream_index, channel_names):
459460
channel_indexes = np.array([chan_names.index(name) for name in channel_names])
460461
return channel_indexes
461462

462-
def channel_id_to_index(self, stream_index, channel_ids):
463+
def channel_id_to_index(self, stream_index: int, channel_ids: list[str]):
463464
"""
464465
Inside a stream, transform channel_ids to channel_indexes.
465466
Based on self.header['signal_channels']
@@ -473,7 +474,7 @@ def channel_id_to_index(self, stream_index, channel_ids):
473474
channel_indexes = np.array([chan_ids.index(chan_id) for chan_id in channel_ids])
474475
return channel_indexes
475476

476-
def _get_channel_indexes(self, stream_index, channel_indexes, channel_names, channel_ids):
477+
def _get_channel_indexes(self, stream_index:int, channel_indexes: list[int]|None, channel_names: list[str]|None, channel_ids: list[str]|None):
477478
"""
478479
Select channel_indexes for a stream based on channel_indexes/channel_names/channel_ids
479480
depending which is not None.
@@ -493,7 +494,7 @@ def _get_stream_index_from_arg(self, stream_index_arg):
493494
stream_index = stream_index_arg
494495
return stream_index
495496

496-
def get_signal_size(self, block_index, seg_index, stream_index=None):
497+
def get_signal_size(self, block_index: int, seg_index: int, stream_index: int|None=None):
497498
"""
498499
Retrieve the length of a single section of the channels in a stream.
499500
:param block_index:
@@ -504,7 +505,7 @@ def get_signal_size(self, block_index, seg_index, stream_index=None):
504505
stream_index = self._get_stream_index_from_arg(stream_index)
505506
return self._get_signal_size(block_index, seg_index, stream_index)
506507

507-
def get_signal_t_start(self, block_index, seg_index, stream_index=None):
508+
def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int|None=None):
508509
"""
509510
Retrieve the t_start of a single section of the channels in a stream.
510511
:param block_index:
@@ -515,7 +516,7 @@ def get_signal_t_start(self, block_index, seg_index, stream_index=None):
515516
stream_index = self._get_stream_index_from_arg(stream_index)
516517
return self._get_signal_t_start(block_index, seg_index, stream_index)
517518

518-
def get_signal_sampling_rate(self, stream_index=None):
519+
def get_signal_sampling_rate(self, stream_index: int| None=None):
519520
"""
520521
Retrieve sampling rate for a stream and all channels in that stream.
521522
:param stream_index:
@@ -528,9 +529,9 @@ def get_signal_sampling_rate(self, stream_index=None):
528529
sr = signal_channels[0]['sampling_rate']
529530
return float(sr)
530531

531-
def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_stop=None,
532-
stream_index=None, channel_indexes=None, channel_names=None,
533-
channel_ids=None, prefer_slice=False):
532+
def get_analogsignal_chunk(self, block_index: int =0, seg_index: int =0, i_start: int|None =None, i_stop: int|None =None,
533+
stream_index: int|None =None, channel_indexes: list[int]|None =None, channel_names: list[str]|None =None,
534+
channel_ids: list[str]|None=None, prefer_slice:bool=False):
534535
"""
535536
Return a chunk of raw signal as a Numpy array. columns correspond to samples from a
536537
section of a single channel of recording. The channels are chosen either by channel_names,
@@ -587,8 +588,8 @@ def get_analogsignal_chunk(self, block_index=0, seg_index=0, i_start=None, i_sto
587588

588589
return raw_chunk
589590

590-
def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index=None,
591-
channel_indexes=None, channel_names=None, channel_ids=None):
591+
def rescale_signal_raw_to_float(self, raw_signal: np.ndarray, dtype: np.dtype='float32', stream_index: int|None=None,
592+
channel_indexes: list[int]|None=None, channel_names: list[str]|None=None, channel_ids: list[str]|None=None):
592593
"""
593594
Rescale a chunk of raw signals which are provided as a Numpy array. These are normally
594595
returned by a call to get_analogsignal_chunk. The channels are specified either by
@@ -627,7 +628,7 @@ def rescale_signal_raw_to_float(self, raw_signal, dtype='float32', stream_index=
627628
return float_signal
628629

629630
# spiketrain and unit zone
630-
def spike_count(self, block_index=0, seg_index=0, spike_channel_index=0):
631+
def spike_count(self, block_index: int=0, seg_index: int=0, spike_channel_index:int =0):
631632
return self._spike_count(block_index, seg_index, spike_channel_index)
632633

633634
def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0,
@@ -643,21 +644,21 @@ def get_spike_timestamps(self, block_index=0, seg_index=0, spike_channel_index=0
643644
spike_channel_index, t_start, t_stop)
644645
return timestamp
645646

646-
def rescale_spike_timestamp(self, spike_timestamps, dtype='float64'):
647+
def rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype='float64'):
647648
"""
648649
Rescale spike timestamps to seconds.
649650
"""
650651
return self._rescale_spike_timestamp(spike_timestamps, dtype)
651652

652653
# spiketrain waveform zone
653-
def get_spike_raw_waveforms(self, block_index=0, seg_index=0, spike_channel_index=0,
654-
t_start=None, t_stop=None):
654+
def get_spike_raw_waveforms(self, block_index: int=0, seg_index: int=0, spike_channel_index: int=0,
655+
t_start: float|None =None, t_stop: float|None=None):
655656
wf = self._get_spike_raw_waveforms(block_index, seg_index,
656657
spike_channel_index, t_start, t_stop)
657658
return wf
658659

659-
def rescale_waveforms_to_float(self, raw_waveforms, dtype='float32',
660-
spike_channel_index=0):
660+
def rescale_waveforms_to_float(self, raw_waveforms: np.ndarray, dtype: np.dtype ='float32',
661+
spike_channel_index: int =0):
661662
wf_gain = self.header['spike_channels']['wf_gain'][spike_channel_index]
662663
wf_offset = self.header['spike_channels']['wf_offset'][spike_channel_index]
663664

@@ -671,11 +672,11 @@ def rescale_waveforms_to_float(self, raw_waveforms, dtype='float32',
671672
return float_waveforms
672673

673674
# event and epoch zone
674-
def event_count(self, block_index=0, seg_index=0, event_channel_index=0):
675+
def event_count(self, block_index:int =0, seg_index: int =0, event_channel_index: int =0):
675676
return self._event_count(block_index, seg_index, event_channel_index)
676677

677-
def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0,
678-
t_start=None, t_stop=None):
678+
def get_event_timestamps(self, block_index: int =0, seg_index: int =0, event_channel_index: int =0,
679+
t_start: float|None =None, t_stop: float|None=None):
679680
"""
680681
The timestamp datatype is as close to the format itself. Sometimes float/int32/int64.
681682
Sometimes it is the index on the signal but not always.
@@ -693,21 +694,21 @@ def get_event_timestamps(self, block_index=0, seg_index=0, event_channel_index=0
693694
block_index, seg_index, event_channel_index, t_start, t_stop)
694695
return timestamp, durations, labels
695696

696-
def rescale_event_timestamp(self, event_timestamps, dtype='float64',
697-
event_channel_index=0):
697+
def rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype='float64',
698+
event_channel_index:int =0):
698699
"""
699700
Rescale event timestamps to seconds.
700701
"""
701702
return self._rescale_event_timestamp(event_timestamps, dtype, event_channel_index)
702703

703-
def rescale_epoch_duration(self, raw_duration, dtype='float64',
704-
event_channel_index=0):
704+
def rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype ='float64',
705+
event_channel_index:int =0):
705706
"""
706707
Rescale epoch raw duration to seconds.
707708
"""
708709
return self._rescale_epoch_duration(raw_duration, dtype, event_channel_index)
709710

710-
def setup_cache(self, cache_path, **init_kargs):
711+
def setup_cache(self, cache_path: 'home'|'same_as_resource', **init_kargs):
711712
try:
712713
import joblib
713714
except ImportError:
@@ -735,7 +736,7 @@ def setup_cache(self, cache_path, **init_kargs):
735736
dirname = os.path.dirname(resource_name)
736737
else:
737738
assert os.path.exists(cache_path), \
738-
'cache_path do not exists use "home" or "same_as_resource" to make this auto'
739+
'cache_path does not exists use "home" or "same_as_resource" to make this auto'
739740

740741
# the hash of the resource (dir of file) is done with filename+datetime
741742
# TODO make something more sophisticated when rawmode='one-dir' that use all
@@ -776,32 +777,32 @@ def _parse_header(self):
776777
def _source_name(self):
777778
raise (NotImplementedError)
778779

779-
def _segment_t_start(self, block_index, seg_index):
780+
def _segment_t_start(self, block_index: int, seg_index: int):
780781
raise (NotImplementedError)
781782

782-
def _segment_t_stop(self, block_index, seg_index):
783+
def _segment_t_stop(self, block_index:int , seg_index: int):
783784
raise (NotImplementedError)
784785

785786
###
786787
# signal and channel zone
787-
def _get_signal_size(self, block_index, seg_index, stream_index):
788+
def _get_signal_size(self, block_index: int, seg_index: int, stream_index: int):
788789
"""
789790
Return the size of a set of AnalogSignals indexed by channel_indexes.
790791
791792
All channels indexed must have the same size and t_start.
792793
"""
793794
raise (NotImplementedError)
794795

795-
def _get_signal_t_start(self, block_index, seg_index, stream_index):
796+
def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int):
796797
"""
797798
Return the t_start of a set of AnalogSignals indexed by channel_indexes.
798799
799800
All channels indexed must have the same size and t_start.
800801
"""
801802
raise (NotImplementedError)
802803

803-
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
804-
stream_index, channel_indexes):
804+
def _get_analogsignal_chunk(self, block_index: int, seg_index: int, i_start: int|None, i_stop: int|None,
805+
stream_index: int, channel_indexes: list[int]|None):
805806
"""
806807
Return the samples from a set of AnalogSignals indexed
807808
by stream_index and channel_indexes (local index inner stream).
@@ -815,38 +816,38 @@ def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop,
815816

816817
###
817818
# spiketrain and unit zone
818-
def _spike_count(self, block_index, seg_index, spike_channel_index):
819+
def _spike_count(self, block_index: int, seg_index: int, spike_channel_index: int):
819820
raise (NotImplementedError)
820821

821-
def _get_spike_timestamps(self, block_index, seg_index,
822-
spike_channel_index, t_start, t_stop):
822+
def _get_spike_timestamps(self, block_index: int, seg_index: int,
823+
spike_channel_index: int, t_start: float|None, t_stop: float|None):
823824
raise (NotImplementedError)
824825

825-
def _rescale_spike_timestamp(self, spike_timestamps, dtype):
826+
def _rescale_spike_timestamp(self, spike_timestamps: np.ndarray, dtype: np.dtype):
826827
raise (NotImplementedError)
827828

828829
###
829830
# spike waveforms zone
830-
def _get_spike_raw_waveforms(self, block_index, seg_index,
831-
spike_channel_index, t_start, t_stop):
831+
def _get_spike_raw_waveforms(self, block_index: int, seg_index: int,
832+
spike_channel_index: int, t_start: float|None, t_stop: float|None):
832833
raise (NotImplementedError)
833834

834835
###
835836
# event and epoch zone
836-
def _event_count(self, block_index, seg_index, event_channel_index):
837+
def _event_count(self, block_index: int, seg_index: int, event_channel_index: int):
837838
raise (NotImplementedError)
838839

839-
def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_start, t_stop):
840+
def _get_event_timestamps(self, block_index: int, seg_index: int, event_channel_index: int, t_start: float|None, t_stop: float|None):
840841
raise (NotImplementedError)
841842

842-
def _rescale_event_timestamp(self, event_timestamps, dtype):
843+
def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype):
843844
raise (NotImplementedError)
844845

845-
def _rescale_epoch_duration(self, raw_duration, dtype):
846+
def _rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype):
846847
raise (NotImplementedError)
847848

848849

849-
def pprint_vector(vector, lim=8):
850+
def pprint_vector(vector, lim: int=8):
850851
vector = np.asarray(vector)
851852
assert vector.ndim == 1
852853
if len(vector) > lim:

0 commit comments

Comments
 (0)