Skip to content

Commit 73962b7

Browse files
Reverts a54319e
PiperOrigin-RevId: 702405512
1 parent c4d19ca commit 73962b7

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

jax/_src/lax/lax.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -923,15 +923,8 @@ def accumulation_type(self) -> DTypeLike | None:
923923
case _:
924924
return np.float32
925925

926-
def supported_output_types(
927-
self, lhs_dtype: DTypeLike, rhs_dtype: DTypeLike
928-
) -> tuple[DTypeLike, ...] | None:
929-
if np.dtype(lhs_dtype) != np.dtype(rhs_dtype):
930-
raise ValueError(
931-
f"The dot algorithm '{self}' requires both inputs to have the same "
932-
f'dtypes. Got {lhs_dtype} and {rhs_dtype} instead.'
933-
)
934-
926+
@property
927+
def supported_output_types(self) -> tuple[DTypeLike, ...] | None:
935928
match self:
936929
case (
937930
DotAlgorithmPreset.ANY_F8_ANY_F8_F32
@@ -949,11 +942,6 @@ def supported_output_types(
949942
)
950943
case DotAlgorithmPreset.F16_F16_F32:
951944
return (np.float32, np.float16)
952-
case DotAlgorithmPreset.BF16_BF16_F32:
953-
if np.dtype(lhs_dtype) == dtypes.bfloat16:
954-
return (np.float32, dtypes.bfloat16)
955-
else:
956-
return (np.float32,)
957945
case _:
958946
accumulation_type = self.accumulation_type
959947
return None if accumulation_type is None else (accumulation_type,)
@@ -3725,19 +3713,25 @@ def get_algorithm_compute_types(
37253713
algorithm.accumulation_type,
37263714
)
37273715

3716+
supported_output_types = algorithm.supported_output_types
3717+
3718+
if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
3719+
# If dtype is anything other than float32, it will be cast to bfloat16.
3720+
if np.dtype(lhs_dtype) != np.float32:
3721+
supported_output_types = (np.float32, dtypes.bfloat16)
3722+
37283723
def maybe_convert_dtype(input_dtype, target_dtypes):
37293724
if target_dtypes is None:
37303725
return input_dtype
37313726
if np.dtype(input_dtype) in map(np.dtype, target_dtypes):
37323727
return input_dtype
37333728
return target_dtypes[0]
37343729

3735-
lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types)
3736-
rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types)
3737-
out_type = maybe_convert_dtype(
3738-
out_dtype, algorithm.supported_output_types(lhs_dtype, rhs_dtype)
3730+
return (
3731+
maybe_convert_dtype(lhs_dtype, algorithm.supported_lhs_types),
3732+
maybe_convert_dtype(rhs_dtype, algorithm.supported_rhs_types),
3733+
maybe_convert_dtype(out_dtype, supported_output_types),
37393734
)
3740-
return lhs_dtype, rhs_dtype, out_type
37413735

37423736

37433737
def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,

0 commit comments

Comments
 (0)