1
1
from __future__ import annotations
2
2
3
3
import contextlib
4
+ import typing
4
5
import warnings
5
- from typing import TYPE_CHECKING , Any
6
6
7
7
# array-api-strict#6
8
- import array_api_strict as xp # type: ignore[import-untyped]
8
+ import array_api_strict as xp # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
9
9
import numpy as np
10
10
import pytest
11
11
from numpy .testing import assert_allclose , assert_array_equal , assert_equal
12
12
13
13
from array_api_extra import atleast_nd , cov , create_diagonal , expand_dims , kron , sinc
14
14
15
- if TYPE_CHECKING :
16
- Array = Any # To be changed to a Protocol later (see array-api#589)
15
+ if typing . TYPE_CHECKING :
16
+ from array_api_extra . _typing import Array
17
17
18
18
19
19
class TestAtLeastND :
@@ -131,7 +131,7 @@ def test_1d(self):
131
131
132
132
@pytest .mark .parametrize ("n" , range (1 , 10 ))
133
133
@pytest .mark .parametrize ("offset" , range (1 , 10 ))
134
- def test_create_diagonal (self , n , offset ):
134
+ def test_create_diagonal (self , n : int , offset : int ):
135
135
# from scipy._lib tests
136
136
rng = np .random .default_rng (2347823 )
137
137
one = xp .asarray (1.0 )
@@ -180,9 +180,9 @@ def test_basic(self):
180
180
assert_array_equal (kron (a , b , xp = xp ), k )
181
181
182
182
def test_kron_smoke (self ):
183
- a = xp .ones ([ 3 , 3 ] )
184
- b = xp .ones ([ 3 , 3 ] )
185
- k = xp .ones ([ 9 , 9 ] )
183
+ a = xp .ones (( 3 , 3 ) )
184
+ b = xp .ones (( 3 , 3 ) )
185
+ k = xp .ones (( 9 , 9 ) )
186
186
187
187
assert_array_equal (kron (a , b , xp = xp ), k )
188
188
@@ -197,7 +197,7 @@ def test_kron_smoke(self):
197
197
((2 , 0 , 0 , 2 ), (2 , 0 , 2 )),
198
198
],
199
199
)
200
- def test_kron_shape (self , shape_a , shape_b ):
200
+ def test_kron_shape (self , shape_a : tuple [ int , ...], shape_b : tuple [ int , ...] ):
201
201
a = xp .ones (shape_a )
202
202
b = xp .ones (shape_b )
203
203
normalised_shape_a = xp .asarray (
@@ -271,7 +271,7 @@ def test_simple(self):
271
271
assert_allclose (w , xp .flip (w , axis = 0 ))
272
272
273
273
@pytest .mark .parametrize ("x" , [0 , 1 + 3j ])
274
- def test_dtype (self , x ):
274
+ def test_dtype (self , x : int | complex ):
275
275
with pytest .raises (ValueError , match = "real floating data type" ):
276
276
sinc (xp .asarray (x ), xp = xp )
277
277
0 commit comments