1
1
from __future__ import annotations
2
2
3
3
from functools import wraps
4
- from builtins import all as builtin_all
4
+ from builtins import all as builtin_all , any as builtin_any
5
5
6
6
from ..common ._aliases import (UniqueAllResult , UniqueCountsResult ,
7
7
UniqueInverseResult ,
19
19
20
20
array = torch .Tensor
21
21
22
- _array_api_dtypes = {
23
- torch .bool ,
22
+ _int_dtypes = {
24
23
torch .uint8 ,
25
24
torch .int8 ,
26
25
torch .int16 ,
27
26
torch .int32 ,
28
27
torch .int64 ,
28
+ }
29
+
30
+ _array_api_dtypes = {
31
+ torch .bool ,
32
+ * _int_dtypes ,
29
33
torch .float32 ,
30
34
torch .float64 ,
31
35
}
@@ -611,6 +615,43 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
611
615
x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
612
616
return torch .tensordot (x1 , x2 , dims = axes , ** kwargs )
613
617
618
+
619
+ def isdtype (
620
+ dtype : Dtype , kind : Union [Dtype , str , Tuple [Union [Dtype , str ], ...]],
621
+ * , _tuple = True , # Disallow nested tuples
622
+ ) -> bool :
623
+ """
624
+ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
625
+
626
+ Note that outside of this function, this compat library does not yet fully
627
+ support complex numbers.
628
+
629
+ See
630
+ https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
631
+ for more details
632
+ """
633
+ if isinstance (kind , tuple ) and _tuple :
634
+ return builtin_any (isdtype (dtype , k , _tuple = False ) for k in kind )
635
+ elif isinstance (kind , str ):
636
+ if kind == 'bool' :
637
+ return dtype == torch .bool
638
+ elif kind == 'signed integer' :
639
+ return dtype in _int_dtypes and dtype .is_signed
640
+ elif kind == 'unsigned integer' :
641
+ return dtype in _int_dtypes and not dtype .is_signed
642
+ elif kind == 'integral' :
643
+ return dtype in _int_dtypes
644
+ elif kind == 'real floating' :
645
+ return dtype .is_floating_point
646
+ elif kind == 'complex floating' :
647
+ return dtype .is_complex
648
+ elif kind == 'numeric' :
649
+ return isdtype (dtype , ('integral' , 'real floating' , 'complex floating' ))
650
+ else :
651
+ raise ValueError (f"Unrecognized data type kind: { kind !r} " )
652
+ else :
653
+ return dtype == kind
654
+
614
655
__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' , 'add' ,
615
656
'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
616
657
'bitwise_right_shift' , 'bitwise_xor' , 'divide' , 'equal' ,
@@ -622,4 +663,4 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
622
663
'zeros' , 'empty' , 'tril' , 'triu' , 'expand_dims' , 'astype' ,
623
664
'broadcast_arrays' , 'unique_all' , 'unique_counts' ,
624
665
'unique_inverse' , 'unique_values' , 'matmul' , 'matrix_transpose' ,
625
- 'vecdot' , 'tensordot' ]
666
+ 'vecdot' , 'tensordot' , 'isdtype' ]
0 commit comments