Skip to content

Commit 5d5b06c

Browse files
chr1sj0nesGoogle-ML-Automation
authored andcommitted
[jax] Canonicalize dtypes when checking if dtypes present in target dtypes list.
PiperOrigin-RevId: 701961663
1 parent 7b32d88 commit 5d5b06c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

jax/_src/lax/lax.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3704,12 +3704,14 @@ def maybe_convert_dtype(input_dtype, target_dtype):
37043704
return input_dtype
37053705
if not isinstance(target_dtype, tuple):
37063706
target_dtype = (target_dtype,)
3707-
return input_dtype if input_dtype in target_dtype else target_dtype[0]
3707+
if np.dtype(input_dtype) in map(np.dtype, target_dtype):
3708+
return input_dtype
3709+
return target_dtype[0]
37083710

37093711
if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
37103712
lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type)
37113713
rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type)
3712-
if lhs_dtype == dtypes.bfloat16:
3714+
if np.dtype(lhs_dtype) == dtypes.bfloat16:
37133715
out_dtype = maybe_convert_dtype(out_dtype,
37143716
(np.float32, dtypes.bfloat16))
37153717
else:

0 commit comments

Comments
 (0)