77import warnings
88from collections .abc import Sequence
99from types import ModuleType
10- from typing import cast
10+ from typing import TYPE_CHECKING , cast
1111
1212from ._at import at
1313from ._utils import _compat , _helpers
@@ -375,8 +375,8 @@ def expand_dims(
375375
376376
377377def isclose (
378- a : Array ,
379- b : Array ,
378+ a : Array | complex ,
379+ b : Array | complex ,
380380 * ,
381381 rtol : float = 1e-05 ,
382382 atol : float = 1e-08 ,
@@ -385,6 +385,10 @@ def isclose(
385385) -> Array : # numpydoc ignore=PR01,RT01
386386 """See docstring in array_api_extra._delegation."""
387387 a , b = asarrays (a , b , xp = xp )
388+ # FIXME https://github.com/microsoft/pyright/issues/10085
389+ if TYPE_CHECKING : # pragma: nocover
390+ assert _compat .is_array_api_obj (a )
391+ assert _compat .is_array_api_obj (b )
388392
389393 a_inexact = xp .isdtype (a .dtype , ("real floating" , "complex floating" ))
390394 b_inexact = xp .isdtype (b .dtype , ("real floating" , "complex floating" ))
@@ -419,7 +423,13 @@ def isclose(
419423 return xp .abs (a - b ) <= (atol + xp .abs (b ) // nrtol )
420424
421425
422- def kron (a : Array , b : Array , / , * , xp : ModuleType | None = None ) -> Array :
426+ def kron (
427+ a : Array | complex ,
428+ b : Array | complex ,
429+ / ,
430+ * ,
431+ xp : ModuleType | None = None ,
432+ ) -> Array :
423433 """
424434 Kronecker product of two arrays.
425435
@@ -495,9 +505,16 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
495505 if xp is None :
496506 xp = array_namespace (a , b )
497507 a , b = asarrays (a , b , xp = xp )
508+ # FIXME https://github.com/microsoft/pyright/issues/10085
509+ if TYPE_CHECKING : # pragma: nocover
510+ assert _compat .is_array_api_obj (a )
511+ assert _compat .is_array_api_obj (b )
498512
499513 singletons = (1 ,) * (b .ndim - a .ndim )
500514 a = xp .broadcast_to (a , singletons + a .shape )
515+ # FIXME https://github.com/microsoft/pyright/issues/10085
516+ if TYPE_CHECKING : # pragma: nocover
517+ assert _compat .is_array_api_obj (a )
501518
502519 nd_b , nd_a = b .ndim , a .ndim
503520 nd_max = max (nd_b , nd_a )
@@ -614,8 +631,8 @@ def pad(
614631
615632
616633def setdiff1d (
617- x1 : Array ,
618- x2 : Array ,
634+ x1 : Array | complex ,
635+ x2 : Array | complex ,
619636 / ,
620637 * ,
621638 assume_unique : bool = False ,
@@ -628,7 +645,7 @@ def setdiff1d(
628645
629646 Parameters
630647 ----------
631- x1 : array
648+ x1 : array | int | float | complex | bool
632649 Input array.
633650 x2 : array
634651 Input comparison array.
@@ -665,6 +682,11 @@ def setdiff1d(
665682 else :
666683 x1 = xp .unique_values (x1 )
667684 x2 = xp .unique_values (x2 )
685+
686+ # FIXME https://github.com/microsoft/pyright/issues/10085
687+ if TYPE_CHECKING : # pragma: nocover
688+ assert _compat .is_array_api_obj (x1 )
689+
668690 return x1 [_helpers .in1d (x1 , x2 , assume_unique = True , invert = True , xp = xp )]
669691
670692
0 commit comments