Skip to content

Commit 461d552

Browse files
authored
[BUG] Fix NumPy int and floats being unrecognised in type checks (#108)
1 parent ba0d66a commit 461d552

File tree

10 files changed

+58
-48
lines changed

10 files changed

+58
-48
lines changed

src/pybispectra/general/general.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
_compute_threenorm,
1010
_ProcessBispectrum,
1111
)
12-
from pybispectra.utils._utils import _compute_in_parallel
12+
from pybispectra.utils._utils import _compute_in_parallel, _int_like
1313

1414

1515
class _General(_ProcessBispectrum):
@@ -42,7 +42,7 @@ def _sort_indices(self, indices: tuple[tuple[int]] | None) -> None:
4242
for group_idcs in indices:
4343
if not isinstance(group_idcs, tuple):
4444
raise TypeError("Entries of `indices` must be tuples.")
45-
if any(not isinstance(idx, int) for idx in group_idcs):
45+
if any(not isinstance(idx, _int_like) for idx in group_idcs):
4646
raise TypeError("Entries for groups in `indices` must be ints.")
4747
if any(idx < 0 or idx >= self._n_chans for idx in group_idcs):
4848
raise ValueError(

src/pybispectra/tde/tde.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pybispectra.utils import ResultsTDE
1111
from pybispectra.utils._defaults import _precision
1212
from pybispectra.utils._process import _ProcessBispectrum
13-
from pybispectra.utils._utils import _compute_in_parallel
13+
from pybispectra.utils._utils import _compute_in_parallel, _number_like, _int_like
1414

1515

1616
class TDE(_ProcessBispectrum):
@@ -291,14 +291,14 @@ def _sort_freq_bands(
291291
fmax: int | float | tuple[int | float],
292292
) -> None:
293293
"""Sort inputs for the frequency bounds."""
294-
if not isinstance(fmin, (int, float, tuple)):
294+
if not isinstance(fmin, _number_like + (tuple,)):
295295
raise TypeError("`fmin` must be an int, float, or tuple.")
296-
if not isinstance(fmax, (int, float, tuple)):
296+
if not isinstance(fmax, _number_like + (tuple,)):
297297
raise TypeError("`fmax` must be an int, float, or tuple.")
298298

299-
if isinstance(fmin, (int, float)):
299+
if isinstance(fmin, _number_like):
300300
fmin = (fmin,)
301-
if isinstance(fmax, (int, float)):
301+
if isinstance(fmax, _number_like):
302302
fmax = (fmax,)
303303

304304
new_fmax = []
@@ -343,12 +343,12 @@ def _sort_metrics(
343343
"""Sort inputs for the form of results being requested."""
344344
if not isinstance(antisym, (bool, tuple)):
345345
raise TypeError("`antisym` must be a bool or tuple of bools.")
346-
if not isinstance(method, (int, tuple)):
346+
if not isinstance(method, _int_like + (tuple,)):
347347
raise TypeError("`method` must be an int or tuple of ints.")
348348

349349
if isinstance(antisym, bool):
350350
antisym = (antisym,)
351-
if isinstance(method, int):
351+
if isinstance(method, _int_like):
352352
method = (method,)
353353

354354
if any(not isinstance(entry, bool) for entry in antisym):
@@ -388,7 +388,7 @@ def _sort_indices(self, indices: tuple[tuple[int]] | None) -> None:
388388
for group_idcs in (seeds, targets):
389389
if not isinstance(group_idcs, tuple):
390390
raise TypeError("Entries of `indices` must be tuples.")
391-
if any(not isinstance(idx, int) for idx in group_idcs):
391+
if any(not isinstance(idx, _int_like) for idx in group_idcs):
392392
raise TypeError(
393393
"Entries for seeds and targets in `indices` must be ints."
394394
)

src/pybispectra/utils/_plot.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from matplotlib.figure import Figure
88
from matplotlib.ticker import ScalarFormatter, StrMethodFormatter
99

10+
from pybispectra.utils._utils import _int_like, _number_like
11+
1012

1113
class _PlotBase(ABC):
1214
"""Base class for plotting results.
@@ -52,22 +54,22 @@ def _sort_plot_inputs(
5254
"""
5355
if nodes is None:
5456
nodes = tuple(range(self.n_nodes))
55-
if not isinstance(nodes, (int, tuple)):
57+
if not isinstance(nodes, _int_like + (tuple,)):
5658
raise TypeError("`nodes` must be an int or tuple.")
5759
if isinstance(nodes, int):
5860
nodes = (nodes,)
59-
if not all(isinstance(con, int) for con in nodes):
61+
if not all(isinstance(con, _int_like) for con in nodes):
6062
raise TypeError("Entries of `nodes` must be ints.")
6163
if any(con >= self.n_nodes for con in nodes) or any(con < 0 for con in nodes):
6264
raise ValueError("The requested node is not present in the results.")
6365

64-
if not isinstance(n_rows, int) or not isinstance(n_cols, int):
66+
if not isinstance(n_rows, _int_like) or not isinstance(n_cols, _int_like):
6567
raise TypeError("`n_rows` and `n_cols` must be integers.")
6668
if n_rows < 1 or n_cols < 1:
6769
raise ValueError("`n_rows` and `n_cols` must be >= 1.")
6870

69-
if not isinstance(major_tick_intervals, (int, float)) or not isinstance(
70-
minor_tick_intervals, (int, float)
71+
if not isinstance(major_tick_intervals, _number_like) or not isinstance(
72+
minor_tick_intervals, _number_like
7173
):
7274
raise TypeError(
7375
"`major_tick_intervals` and `minor_tick_intervals` should be ints or "
@@ -1117,11 +1119,11 @@ def _sort_freq_band_inputs(self, freq_bands: int | tuple[int] | None) -> tuple[i
11171119
if freq_bands is None:
11181120
freq_bands = tuple(range(self._n_fbands))
11191121
else:
1120-
if not isinstance(freq_bands, (int, tuple)):
1122+
if not isinstance(freq_bands, _int_like + (tuple,)):
11211123
raise TypeError("`freq_bands` must be an int or tuple.")
1122-
if isinstance(freq_bands, int):
1124+
if isinstance(freq_bands, _int_like):
11231125
freq_bands = (freq_bands,)
1124-
if not all(isinstance(fband, int) for fband in freq_bands):
1126+
if not all(isinstance(fband, _int_like) for fband in freq_bands):
11251127
raise TypeError("Entries of `freq_bands` must be ints.")
11261128
if any(fband >= self._n_fbands for fband in freq_bands) or any(
11271129
fband < 0 for fband in freq_bands

src/pybispectra/utils/_process.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from numba import njit
1010

1111
from pybispectra.utils._defaults import _precision
12-
from pybispectra.utils._utils import _fast_find_first
12+
from pybispectra.utils._utils import _fast_find_first, _int_like, _number_like
1313
from pybispectra.utils.results import _ResultsBase
1414

1515

@@ -70,7 +70,7 @@ def _sort_init_inputs(
7070
"`data` and `freqs` must contain the same number of frequencies."
7171
)
7272

73-
if not isinstance(sampling_freq, (int, float)):
73+
if not isinstance(sampling_freq, _number_like):
7474
raise TypeError("`sampling_freq` must be an int or a float.")
7575
if np.abs(freqs).max() > sampling_freq / 2:
7676
raise ValueError(
@@ -114,7 +114,7 @@ def _sort_indices(self, indices: tuple[tuple[int]] | None) -> None:
114114
for group_idcs in (seeds, targets):
115115
if not isinstance(group_idcs, tuple):
116116
raise TypeError("Entries of `indices` must be tuples.")
117-
if any(not isinstance(idx, int) for idx in group_idcs):
117+
if any(not isinstance(idx, _int_like) for idx in group_idcs):
118118
raise TypeError(
119119
"Entries for seeds and targets in `indices` must be ints."
120120
)
@@ -180,7 +180,7 @@ def _sort_freqs(
180180

181181
def _sort_parallelisation(self, n_jobs: int) -> None:
182182
"""Sort parallelisation inputs."""
183-
if not isinstance(n_jobs, int):
183+
if not isinstance(n_jobs, _int_like):
184184
raise TypeError("`n_jobs` must be an integer.")
185185
if n_jobs < 1 and n_jobs != -1:
186186
raise ValueError("`n_jobs` must be >= 1 or -1.")

src/pybispectra/utils/_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
from pybispectra.utils._defaults import _precision
1010

1111

12+
# Aliases for type checking
13+
_int_like = (int, np.integer)
14+
_float_like = (float, np.floating)
15+
_number_like = _int_like + _float_like
16+
17+
1218
def _compute_in_parallel(
1319
func: callable,
1420
loop_kwargs: list[dict],

src/pybispectra/utils/ged.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from mne.time_frequency import csd_array_fourier, csd_array_multitaper
1111

1212
from pybispectra.utils._defaults import _precision
13-
from pybispectra.utils._utils import _create_mne_info
13+
from pybispectra.utils._utils import _create_mne_info, _int_like, _number_like
1414
from pybispectra.utils.utils import compute_rank
1515

1616

@@ -174,7 +174,7 @@ def _sort_init_inputs(self, data: np.ndarray, sampling_freq: float) -> None:
174174
if data.ndim != 3:
175175
raise ValueError("`data` must be a 3D array.")
176176

177-
if not isinstance(sampling_freq, (int, float)):
177+
if not isinstance(sampling_freq, _number_like):
178178
raise TypeError("`sampling_freq` must be an int or a float.")
179179
self.sampling_freq = sampling_freq
180180

@@ -190,14 +190,14 @@ def _sort_freq_bounds(
190190
) -> None:
191191
"""Sort frequency bound inputs."""
192192
if not isinstance(signal_bounds, tuple) or not all(
193-
isinstance(entry, (int, float)) for entry in signal_bounds
193+
isinstance(entry, _number_like) for entry in signal_bounds
194194
):
195195
raise TypeError("`signal_bounds` must be a tuple of ints or floats.")
196196
if not isinstance(noise_bounds, tuple) or not all(
197-
isinstance(entry, (int, float)) for entry in noise_bounds
197+
isinstance(entry, _number_like) for entry in noise_bounds
198198
):
199199
raise TypeError("`noise_bounds` must be a tuple of ints or floats.")
200-
if not isinstance(signal_noise_gap, (int, float)):
200+
if not isinstance(signal_noise_gap, _number_like):
201201
raise TypeError("`signal_noise_gap` must be an int or a float.")
202202

203203
if len(signal_bounds) != 2 or len(noise_bounds) != 2:
@@ -231,7 +231,7 @@ def _sort_bandpass_filter(self, bandpass_filter: bool) -> None:
231231

232232
def _sort_n_harmonics(self, n_harmonics: int) -> None:
233233
"""Sort harmonic use input."""
234-
if not isinstance(n_harmonics, int):
234+
if not isinstance(n_harmonics, _int_like):
235235
raise TypeError("`n_harmonics` must be an int.")
236236

237237
if n_harmonics < -1:
@@ -262,7 +262,7 @@ def _sort_indices(self, indices: tuple[int] | None) -> None:
262262
indices = tuple(np.arange(self._n_chans, dtype=np.int32).tolist())
263263

264264
if not isinstance(indices, tuple) or not all(
265-
isinstance(entry, int) for entry in indices
265+
isinstance(entry, _int_like) for entry in indices
266266
):
267267
raise TypeError("`indices` must be a tuple of ints.")
268268

@@ -280,7 +280,7 @@ def _sort_rank(self, rank: int | None) -> None:
280280
if rank is None:
281281
rank = compute_rank(self.data)
282282

283-
if not isinstance(rank, int):
283+
if not isinstance(rank, _int_like):
284284
raise TypeError("`rank` must be an int.")
285285

286286
if rank < 1 or rank > self._use_n_chans:
@@ -530,7 +530,7 @@ def fit_hpmax(
530530

531531
def _sort_parallelisation(self, n_jobs: int) -> int:
532532
"""Sort parallelisation inputs."""
533-
if not isinstance(n_jobs, int):
533+
if not isinstance(n_jobs, _int_like):
534534
raise TypeError("`n_jobs` must be an integer.")
535535
if n_jobs < 1 and n_jobs != -1:
536536
raise ValueError("`n_jobs` must be >= 1 or -1.")
@@ -844,7 +844,7 @@ def get_transformed_data(
844844
"transformed data."
845845
)
846846

847-
if not isinstance(min_ratio, (int, float)):
847+
if not isinstance(min_ratio, _number_like):
848848
raise TypeError("`min_ratio` must be an int or a float")
849849
if not isinstance(copy, bool):
850850
raise TypeError("`copy` must be a bool.")

src/pybispectra/utils/results.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from matplotlib.figure import Figure
77

88
from pybispectra.utils._plot import _PlotCFC, _PlotGeneral, _PlotTDE, _PlotWaveShape
9+
from pybispectra.utils._utils import _int_like
910

1011

1112
class _ResultsBase(ABC):
@@ -66,7 +67,7 @@ def _sort_indices_seeds_targets(self, indices: tuple[tuple[int]]) -> None:
6667
for group_idcs in (seeds, targets):
6768
if not isinstance(group_idcs, tuple):
6869
raise TypeError("Entries of `indices` must be tuples.")
69-
if any(not isinstance(idx, int) for idx in group_idcs):
70+
if any(not isinstance(idx, _int_like) for idx in group_idcs):
7071
raise TypeError(
7172
"Entries for seeds and targets in `indices` must be ints."
7273
)
@@ -85,7 +86,7 @@ def _sort_indices_channels(self, indices: tuple[int]) -> None:
8586
"""Sort ``indices`` with inputs format [channels]."""
8687
if not isinstance(indices, tuple):
8788
raise TypeError("`indices` must be a tuple.")
88-
if not all(isinstance(idx, int) for idx in indices):
89+
if not all(isinstance(idx, _int_like) for idx in indices):
8990
raise TypeError("Entries of `indices` must be ints.")
9091
if any(idx < 0 for idx in indices):
9192
raise ValueError("Entries of `indices` must be >= 0.")
@@ -103,7 +104,7 @@ def _sort_indices_kmn(self, indices: tuple[tuple[int]]) -> None:
103104
for group_idcs in indices:
104105
if not isinstance(group_idcs, tuple):
105106
raise TypeError("Entries of `indices` must be tuples.")
106-
if any(not isinstance(idx, int) for idx in group_idcs):
107+
if any(not isinstance(idx, _int_like) for idx in group_idcs):
107108
raise TypeError("Entries for groups in `indices` must be ints.")
108109
if any(idx < 0 for idx in group_idcs):
109110
raise ValueError(r"Entries for groups in `indices` must be >= 0.")

src/pybispectra/utils/utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from pybispectra import __version__ as version
1313
from pybispectra.utils._defaults import _precision
14-
from pybispectra.utils._utils import _compute_in_parallel
14+
from pybispectra.utils._utils import _compute_in_parallel, _int_like, _number_like
1515

1616

1717
def compute_fft(
@@ -128,12 +128,12 @@ def _compute_fft_input_checks(
128128
if not np.isreal(data).all():
129129
raise ValueError("`data` must be real-valued.")
130130

131-
if not isinstance(sampling_freq, (int, float)):
131+
if not isinstance(sampling_freq, _number_like):
132132
raise TypeError("`sampling_freq` must be an int or a float.")
133133

134134
if n_points is None:
135135
n_points = data.shape[2]
136-
if not isinstance(n_points, int):
136+
if not isinstance(n_points, _int_like):
137137
raise TypeError("`n_points` must be an integer")
138138

139139
if not isinstance(window, str):
@@ -145,7 +145,7 @@ def _compute_fft_input_checks(
145145
else:
146146
window_func = np.hamming
147147

148-
if not isinstance(n_jobs, int):
148+
if not isinstance(n_jobs, _int_like):
149149
raise TypeError("`n_jobs` must be an integer.")
150150
if n_jobs < 1 and n_jobs != -1:
151151
raise ValueError("`n_jobs` must be >= 1 or -1.")
@@ -291,7 +291,7 @@ def _compute_tfr_input_checks(
291291
if data.ndim != 3:
292292
raise ValueError("`data` must be a 3D array.")
293293

294-
if not isinstance(sampling_freq, (int, float)):
294+
if not isinstance(sampling_freq, _number_like):
295295
raise TypeError("`sampling_freq` must be an int or a float.")
296296

297297
if not isinstance(freqs, np.ndarray):
@@ -315,7 +315,7 @@ def _compute_tfr_input_checks(
315315
else:
316316
tfr_func = time_frequency.tfr_array_multitaper
317317

318-
if not isinstance(n_cycles, (np.ndarray, int, float)):
318+
if not isinstance(n_cycles, _number_like + (np.ndarray,)):
319319
raise TypeError("`n_cycles` must be a NumPy array, an int, or a float.")
320320
if isinstance(n_cycles, np.ndarray):
321321
if n_cycles.shape != freqs.shape:
@@ -334,10 +334,10 @@ def _compute_tfr_input_checks(
334334
raise TypeError("`use_fft` must be a bool.")
335335

336336
if tfr_mode == "multitaper":
337-
if not isinstance(multitaper_time_bandwidth, (int, float)):
337+
if not isinstance(multitaper_time_bandwidth, _number_like):
338338
raise TypeError("`multitaper_time_bandwidth` must be an int or a float.")
339339

340-
if not isinstance(n_jobs, int):
340+
if not isinstance(n_jobs, _int_like):
341341
raise TypeError("`n_jobs` must be an integer.")
342342
if n_jobs < 1 and n_jobs != -1:
343343
raise ValueError("`n_jobs` must be >= 1 or -1.")
@@ -375,7 +375,7 @@ def compute_rank(data: np.ndarray, sv_tol: int | float = 1e-5) -> int:
375375
if data.ndim != 3:
376376
raise ValueError("`data` must be a 3D array.")
377377

378-
if not isinstance(sv_tol, (int, float)):
378+
if not isinstance(sv_tol, _number_like):
379379
raise TypeError("`sv_tol` must be a float or an int.")
380380

381381
singular_vals = np.linalg.svd(data, compute_uv=False).min(axis=0)

src/pybispectra/waveshape/waveshape.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
_ProcessBispectrum,
1010
)
1111
from pybispectra.utils.results import ResultsWaveShape
12-
from pybispectra.utils._utils import _compute_in_parallel
12+
from pybispectra.utils._utils import _compute_in_parallel, _int_like
1313

1414
np.seterr(divide="ignore", invalid="ignore") # no warning for NaN division
1515

@@ -125,7 +125,7 @@ def compute(
125125
where the resulting values lie in the range :math:`[-1, 1]`.
126126
127127
Bicoherence is computed for all values of ``f1s`` and ``f2s``.
128-
128+
129129
.. warning::
130130
For values of ``f1s`` higher than ``f2s`` or where ``f2s + f1s`` exceeds the
131131
Nyquist frequency, a :obj:`numpy.nan` value is returned.
@@ -165,7 +165,7 @@ def _sort_indices(self, indices: tuple[int]) -> None:
165165
indices = tuple(range(self._n_chans))
166166
if not isinstance(indices, tuple):
167167
raise TypeError("`indices` must be a tuple.")
168-
if any(not isinstance(idx, int) for idx in indices):
168+
if any(not isinstance(idx, _int_like) for idx in indices):
169169
raise TypeError("Entries of `indices` must be ints.")
170170

171171
if any(idx < 0 or idx >= self._n_chans for idx in indices):

0 commit comments

Comments
 (0)