
Commit 569c2a3

chr1sj0nes authored and Google-ML-Automation committed
Reverts 73962b7
PiperOrigin-RevId: 703100851
1 parent 39d73a6 commit 569c2a3

1 file changed: +19 -14 lines changed

jax/_src/lax/lax.py

Lines changed: 19 additions & 14 deletions
@@ -923,8 +923,9 @@ def accumulation_type(self) -> DTypeLike | None:
       case _:
         return np.float32

-  @property
-  def supported_output_types(self) -> tuple[DTypeLike, ...] | None:
+  def supported_output_types(
+      self, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike
+  ) -> tuple[DTypeLike, ...] | None:
     match self:
       case (
           DotAlgorithmPreset.ANY_F8_ANY_F8_F32
@@ -941,7 +942,17 @@ def supported_output_types(self) -> tuple[DTypeLike, ...] | None:
             dtypes.float8_e4m3b11fnuz,
         )
       case DotAlgorithmPreset.F16_F16_F32:
-        return (np.float32, np.float16)
+        # F16 output is only supported with F16 inputs.
+        if dtypes.promote_types(lhs_dtype, rhs_dtype) == np.float16:
+          return (np.float32, np.float16)
+        else:
+          return (np.float32,)
+      case DotAlgorithmPreset.BF16_BF16_F32:
+        # BF16 output is only supported with BF16 inputs.
+        if dtypes.promote_types(lhs_dtype, rhs_dtype) == dtypes.bfloat16:
+          return (np.float32, dtypes.bfloat16)
+        else:
+          return (np.float32,)
       case _:
         accumulation_type = self.accumulation_type
         return None if accumulation_type is None else (accumulation_type,)
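A minimal sketch of the effect of the new signature (not part of the commit; it assumes jax.lax.DotAlgorithmPreset is exported, as in recent JAX releases, and that this revision's two-argument supported_output_types is in effect): the supported output types now depend on the operand dtypes, so a narrow output type is only offered when both inputs promote to that same narrow type.

import numpy as np
import jax.numpy as jnp
from jax import lax

preset = lax.DotAlgorithmPreset.BF16_BF16_F32

# Both operands are bfloat16, so bfloat16 remains a permitted output type.
print(preset.supported_output_types(jnp.bfloat16, jnp.bfloat16))
# expected: (float32, bfloat16)

# Mixed bfloat16/float32 operands promote to float32, so only float32 remains.
print(preset.supported_output_types(jnp.bfloat16, np.float32))
# expected: (float32,)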
@@ -3713,25 +3724,19 @@ def get_algorithm_compute_types(
         algorithm.accumulation_type,
     )

-  supported_output_types = algorithm.supported_output_types
-
-  if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
-    # If dtype is anything other than float32, it will be cast to bfloat16.
-    if np.dtype(lhs_dtype) != np.float32:
-      supported_output_types = (np.float32, dtypes.bfloat16)
-
   def maybe_convert_dtype(input_dtype, target_dtypes):
     if target_dtypes is None:
       return input_dtype
     if np.dtype(input_dtype) in map(np.dtype, target_dtypes):
       return input_dtype
     return target_dtypes[0]

-  return (
-      maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types),
-      maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types),
-      maybe_convert_dtype(out_dtype, supported_output_types),
+  lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types)
+  rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types)
+  out_type = maybe_convert_dtype(
+      out_dtype, algorithm.supported_output_types(lhs_dtype, rhs_dtype)
   )
+  return lhs_dtype, rhs_dtype, out_type


 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
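At the user level these presets are selected through the precision argument of the dot operations, which is where get_algorithm_compute_types comes into play. A hedged usage sketch (it assumes a recent JAX release in which precision accepts a DotAlgorithmPreset and a backend that implements the algorithm), using the BF16_BF16_F32 preset whose output-type handling this commit reverts to the input-dependent form:

import jax
import jax.numpy as jnp

x = jnp.ones((128, 128), dtype=jnp.bfloat16)
y = jnp.ones((128, 128), dtype=jnp.bfloat16)

# With matching bfloat16 inputs a bfloat16 output type remains permitted;
# mixing in a float32 operand would force a float32 output per the patched logic.
out = jax.lax.dot(x, y, precision=jax.lax.DotAlgorithmPreset.BF16_BF16_F32)
print(out.dtype)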

0 commit comments