Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 39 additions & 28 deletions tests/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,47 +269,58 @@ def testPromoteDtypesStrict(self):

@jax.numpy_dtype_promotion('standard')
def testPromoteDtypesStandard(self):
assertTypePromotionError = functools.partial(
self.assertRaisesRegex,
dtypes.TypePromotionError,
'Input dtypes .* have no available implicit dtype promotion path.',
dtypes.promote_types,
)

small_fp_dtypes = set(fp8_dtypes + fp4_dtypes)
implicit_int_dtypes = set(signed_dtypes + unsigned_dtypes) - set(intn_dtypes)

for t1 in all_dtypes:
self.assertEqual(t1, dtypes.promote_types(t1, t1))

self.assertEqual(t1, dtypes.promote_types(t1, np.bool_))
# TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8.
if t1 in fp8_dtypes:
continue
if t1 in intn_dtypes:
continue
if t1 in fp4_dtypes:
continue
self.assertEqual(np.dtype(np.complex128),
dtypes.promote_types(t1, np.complex128))
if t1 in small_fp_dtypes or t1 in intn_dtypes:
assertTypePromotionError(t1, np.complex128)
else:
self.assertEqual(
np.dtype(np.complex128), dtypes.promote_types(t1, np.complex128)
)

for t2 in all_dtypes:
# TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8.
if t2 in fp8_dtypes:
continue
if t2 in intn_dtypes:
continue
if t2 in fp4_dtypes:
continue
# Symmetry
self.assertEqual(dtypes.promote_types(t1, t2),
dtypes.promote_types(t2, t1))
if (
(t1 != t2)
and (t1 != np.bool_)
and (t2 != np.bool_)
and (
t1 in intn_dtypes or
t2 in intn_dtypes or
(t1 in small_fp_dtypes and t2 not in implicit_int_dtypes) or
(t2 in small_fp_dtypes and t1 not in implicit_int_dtypes)
)
):
assertTypePromotionError(t1, t2)
assertTypePromotionError(t2, t1)
else:
self.assertEqual(
dtypes.promote_types(t1, t2), dtypes.promote_types(t2, t1)
)

self.assertEqual(np.dtype(np.float32),
dtypes.promote_types(np.float16, dtypes.bfloat16))

# Promotions of non-inexact types against inexact types always prefer
# the inexact types.
# Promotions of exact types against inexact types always prefer the
# inexact types.
for t in float_dtypes + complex_dtypes:
for i in bool_dtypes + signed_dtypes + unsigned_dtypes:
# TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8.
if t in fp8_dtypes:
continue
if t in fp4_dtypes:
continue
if t in intn_dtypes or i in intn_dtypes:
continue
self.assertEqual(t, dtypes.promote_types(t, i))
if i in intn_dtypes:
assertTypePromotionError(t, i)
else:
self.assertEqual(t, dtypes.promote_types(t, i))

# Promotions between exact types, or between inexact types, match NumPy.
for groups in [bool_dtypes + np_signed_dtypes + np_unsigned_dtypes,
Expand Down