88from typing import Tuple , Type
99
1010import numpy
11+ import pytest
1112from dpctl import select_default_device
1213from dpctl .tensor ._numpy_helper import AxisError
1314
@@ -979,7 +980,7 @@ def test_func(*args, **kw):
979980 return decorator
980981
981982
982- def for_dtypes (dtypes , name = "dtype" ):
983+ def for_dtypes (dtypes , name = "dtype" , xfail_dtypes = None ):
983984 """Decorator for parameterized dtype test.
984985
985986 Args:
@@ -1010,7 +1011,11 @@ def test_func(*args, **kw):
10101011
10111012 try :
10121013 kw [name ] = numpy .dtype (dtype ).type
1013- impl (* args , ** kw )
1014+ if xfail_dtypes is not None and dtype in xfail_dtypes :
1015+ impl_ = pytest .mark .xfail (impl )
1016+ else :
1017+ impl_ = impl
1018+ impl_ (* args , ** kw )
10141019 except _skip_classes as e :
10151020 print ("skipped: {} = {} ({})" .format (name , dtype , e ))
10161021 except Exception :
@@ -1041,19 +1046,47 @@ def _get_supported_complex_dtypes():
10411046
10421047
10431048def _get_int_dtypes ():
1044- if config .all_types :
1049+ if config .all_int_types :
10451050 return _signed_dtypes + _unsigned_dtypes
10461051 else :
10471052 return (numpy .int64 , numpy .int32 )
10481053
10491054
1055+ def _get_float_dtypes ():
1056+ if config .float16_types :
1057+ return _regular_float_dtypes + (numpy .float16 ,)
1058+ else :
1059+ return _regular_float_dtypes
1060+
1061+
1062+ def _get_signed_dtypes ():
1063+ if config .all_int_types :
1064+ return tuple (numpy .dtype (i ).type for i in "bhilq" )
1065+ else :
1066+ return (numpy .int32 ,)
1067+
1068+
1069+ def _get_unsigned_dtypes ():
1070+ if config .all_int_types :
1071+ return tuple (numpy .dtype (i ).type for i in "BHILQ" )
1072+ else :
1073+ return (numpy .uint32 ,)
1074+
1075+
1076+ def _get_int_bool_dtypes ():
1077+ if config .bool_types :
1078+ return _int_dtypes + (numpy .bool_ ,)
1079+ else :
1080+ return _int_dtypes
1081+
1082+
10501083_complex_dtypes = _get_supported_complex_dtypes ()
10511084_regular_float_dtypes = _get_supported_float_dtypes ()
1052- _float_dtypes = _regular_float_dtypes # + (numpy.float16, )
1053- _signed_dtypes = tuple ( numpy . dtype ( i ). type for i in "bhilq" )
1054- _unsigned_dtypes = tuple ( numpy . dtype ( i ). type for i in "BHILQ" )
1085+ _float_dtypes = _get_float_dtypes ( )
1086+ _signed_dtypes = _get_signed_dtypes ( )
1087+ _unsigned_dtypes = _get_unsigned_dtypes ( )
10551088_int_dtypes = _get_int_dtypes ()
1056- _int_bool_dtypes = _int_dtypes + ( numpy . bool_ , )
1089+ _int_bool_dtypes = _get_int_bool_dtypes ( )
10571090_regular_dtypes = _regular_float_dtypes + _int_bool_dtypes
10581091_dtypes = _float_dtypes + _int_bool_dtypes
10591092
@@ -1069,7 +1102,7 @@ def _make_all_dtypes(no_float16, no_bool, no_complex):
10691102 else :
10701103 dtypes += _int_bool_dtypes
10711104
1072- if not no_complex :
1105+ if config . complex_types and not no_complex :
10731106 dtypes += _complex_dtypes
10741107
10751108 return dtypes
0 commit comments