diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index 10ce202435d7..c566c6c28612 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -269,47 +269,58 @@ def testPromoteDtypesStrict(self): @jax.numpy_dtype_promotion('standard') def testPromoteDtypesStandard(self): + assertTypePromotionError = functools.partial( + self.assertRaisesRegex, + dtypes.TypePromotionError, + 'Input dtypes .* have no available implicit dtype promotion path.', + dtypes.promote_types, + ) + + small_fp_dtypes = set(fp8_dtypes + fp4_dtypes) + implicit_int_dtypes = set(signed_dtypes + unsigned_dtypes) - set(intn_dtypes) + for t1 in all_dtypes: self.assertEqual(t1, dtypes.promote_types(t1, t1)) - self.assertEqual(t1, dtypes.promote_types(t1, np.bool_)) # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8. - if t1 in fp8_dtypes: - continue - if t1 in intn_dtypes: - continue - if t1 in fp4_dtypes: - continue - self.assertEqual(np.dtype(np.complex128), - dtypes.promote_types(t1, np.complex128)) + if t1 in small_fp_dtypes or t1 in intn_dtypes: + assertTypePromotionError(t1, np.complex128) + else: + self.assertEqual( + np.dtype(np.complex128), dtypes.promote_types(t1, np.complex128) + ) for t2 in all_dtypes: # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8. - if t2 in fp8_dtypes: - continue - if t2 in intn_dtypes: - continue - if t2 in fp4_dtypes: - continue - # Symmetry - self.assertEqual(dtypes.promote_types(t1, t2), - dtypes.promote_types(t2, t1)) + if ( + (t1 != t2) + and (t1 != np.bool_) + and (t2 != np.bool_) + and ( + t1 in intn_dtypes or + t2 in intn_dtypes or + (t1 in small_fp_dtypes and t2 not in implicit_int_dtypes) or + (t2 in small_fp_dtypes and t1 not in implicit_int_dtypes) + ) + ): + assertTypePromotionError(t1, t2) + assertTypePromotionError(t2, t1) + else: + self.assertEqual( + dtypes.promote_types(t1, t2), dtypes.promote_types(t2, t1) + ) self.assertEqual(np.dtype(np.float32), dtypes.promote_types(np.float16, dtypes.bfloat16)) - # Promotions of non-inexact types against inexact types always prefer - # the inexact types. + # Promotions of exact types against inexact types always prefer the + # inexact types. for t in float_dtypes + complex_dtypes: for i in bool_dtypes + signed_dtypes + unsigned_dtypes: - # TODO(zhangqiaorjc): Consider more dtype promotion rules for fp8. - if t in fp8_dtypes: - continue - if t in fp4_dtypes: - continue - if t in intn_dtypes or i in intn_dtypes: - continue - self.assertEqual(t, dtypes.promote_types(t, i)) + if i in intn_dtypes: + assertTypePromotionError(t, i) + else: + self.assertEqual(t, dtypes.promote_types(t, i)) # Promotions between exact types, or between inexact types, match NumPy. for groups in [bool_dtypes + np_signed_dtypes + np_unsigned_dtypes,