1010from mne .time_frequency import csd_array_fourier , csd_array_multitaper
1111
1212from pybispectra .utils ._defaults import _precision
13- from pybispectra .utils ._utils import _create_mne_info
13+ from pybispectra .utils ._utils import _create_mne_info , _int_like , _number_like
1414from pybispectra .utils .utils import compute_rank
1515
1616
@@ -174,7 +174,7 @@ def _sort_init_inputs(self, data: np.ndarray, sampling_freq: float) -> None:
174174 if data .ndim != 3 :
175175 raise ValueError ("`data` must be a 3D array." )
176176
177- if not isinstance (sampling_freq , ( int , float ) ):
177+ if not isinstance (sampling_freq , _number_like ):
178178 raise TypeError ("`sampling_freq` must be an int or a float." )
179179 self .sampling_freq = sampling_freq
180180
@@ -190,14 +190,14 @@ def _sort_freq_bounds(
190190 ) -> None :
191191 """Sort frequency bound inputs."""
192192 if not isinstance (signal_bounds , tuple ) or not all (
193- isinstance (entry , ( int , float ) ) for entry in signal_bounds
193+ isinstance (entry , _number_like ) for entry in signal_bounds
194194 ):
195195 raise TypeError ("`signal_bounds` must be a tuple of ints or floats." )
196196 if not isinstance (noise_bounds , tuple ) or not all (
197- isinstance (entry , ( int , float ) ) for entry in noise_bounds
197+ isinstance (entry , _number_like ) for entry in noise_bounds
198198 ):
199199 raise TypeError ("`noise_bounds` must be a tuple of ints or floats." )
200- if not isinstance (signal_noise_gap , ( int , float ) ):
200+ if not isinstance (signal_noise_gap , _number_like ):
201201 raise TypeError ("`signal_noise_gap` must be an int or a float." )
202202
203203 if len (signal_bounds ) != 2 or len (noise_bounds ) != 2 :
@@ -231,7 +231,7 @@ def _sort_bandpass_filter(self, bandpass_filter: bool) -> None:
231231
232232 def _sort_n_harmonics (self , n_harmonics : int ) -> None :
233233 """Sort harmonic use input."""
234- if not isinstance (n_harmonics , int ):
234+ if not isinstance (n_harmonics , _int_like ):
235235 raise TypeError ("`n_harmonics` must be an int." )
236236
237237 if n_harmonics < - 1 :
@@ -262,7 +262,7 @@ def _sort_indices(self, indices: tuple[int] | None) -> None:
262262 indices = tuple (np .arange (self ._n_chans , dtype = np .int32 ).tolist ())
263263
264264 if not isinstance (indices , tuple ) or not all (
265- isinstance (entry , int ) for entry in indices
265+ isinstance (entry , _int_like ) for entry in indices
266266 ):
267267 raise TypeError ("`indices` must be a tuple of ints." )
268268
@@ -280,7 +280,7 @@ def _sort_rank(self, rank: int | None) -> None:
280280 if rank is None :
281281 rank = compute_rank (self .data )
282282
283- if not isinstance (rank , int ):
283+ if not isinstance (rank , _int_like ):
284284 raise TypeError ("`rank` must be an int." )
285285
286286 if rank < 1 or rank > self ._use_n_chans :
@@ -530,7 +530,7 @@ def fit_hpmax(
530530
531531 def _sort_parallelisation (self , n_jobs : int ) -> int :
532532 """Sort parallelisation inputs."""
533- if not isinstance (n_jobs , int ):
533+ if not isinstance (n_jobs , _int_like ):
534534 raise TypeError ("`n_jobs` must be an integer." )
535535 if n_jobs < 1 and n_jobs != - 1 :
536536 raise ValueError ("`n_jobs` must be >= 1 or -1." )
@@ -844,7 +844,7 @@ def get_transformed_data(
844844 "transformed data."
845845 )
846846
847- if not isinstance (min_ratio , ( int , float ) ):
847+ if not isinstance (min_ratio , _number_like ):
848848 raise TypeError ("`min_ratio` must be an int or a float" )
849849 if not isinstance (copy , bool ):
850850 raise TypeError ("`copy` must be a bool." )
0 commit comments