Commit 1e9e85a

chr1sj0nes authored and Google-ML-Automation committed
Simplify handling of DotAlgorithmPreset output types.
Create a clear distinction between the type used for accumulation and possible output types.

PiperOrigin-RevId: 698399447
1 parent a582df0 commit 1e9e85a
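
For context, a minimal sketch (not part of the commit) of the distinction this change draws, assuming `DotAlgorithmPreset` is re-exported as `jax.lax.DotAlgorithmPreset` as in current JAX:

```python
# Illustrative only: accumulation happens in a single dtype, while a preset
# may admit several output dtypes.
from jax import lax

preset = lax.DotAlgorithmPreset.F16_F16_F32

# Accumulation dtype is exactly one type (or None when unconstrained)...
print(preset.accumulation_type)       # expected: np.float32

# ...while more than one output dtype can be supported for the same preset.
print(preset.supported_output_types)  # expected: (np.float32, np.float16)
```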

File tree

1 file changed (+52, -32 lines)

jax/_src/lax/lax.py

Lines changed: 52 additions & 32 deletions
@@ -879,11 +879,11 @@ def __str__(self) -> str:
   def lhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
     match self:
       case (
-          DotAlgorithmPreset.DEFAULT |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
+          DotAlgorithmPreset.DEFAULT
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_F32
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
       ):
         return None
       case DotAlgorithmPreset.F16_F16_F16 | DotAlgorithmPreset.F16_F16_F32:
@@ -906,31 +906,38 @@ def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
     return self.lhs_precision_type
 
   @property
-  def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
+  def accumulation_type(self) -> DTypeLike | None:
     match self:
       case (
-          DotAlgorithmPreset.DEFAULT |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
+          DotAlgorithmPreset.DEFAULT
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
       ):
         return None
+      case DotAlgorithmPreset.F16_F16_F16:
+        return np.float16
+      case DotAlgorithmPreset.BF16_BF16_BF16:
+        return dtypes.bfloat16
+      case DotAlgorithmPreset.F64_F64_F64:
+        return np.float64
+      case _:
+        return np.float32
+
+  @property
+  def supported_output_types(self) -> tuple[DTypeLike, ...] | None:
+    match self:
       case (
           DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
           DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
       ):
         return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn,
                 dtypes.float8_e5m2, dtypes.float8_e5m2fnuz,
                 dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz)
-      case DotAlgorithmPreset.F16_F16_F16:
-        return np.float16
       case DotAlgorithmPreset.F16_F16_F32:
         return (np.float32, np.float16)
-      case DotAlgorithmPreset.BF16_BF16_BF16:
-        return dtypes.bfloat16
-      case DotAlgorithmPreset.F64_F64_F64:
-        return np.float64
       case _:
-        return np.float32
+        accumulation_type = self.accumulation_type
+        return None if accumulation_type is None else (accumulation_type,)
 
   def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
                            rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None:
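
The `case _` fallback above ties the two properties together: unless a preset lists explicit output types, its supported outputs are derived from its accumulation type. A small hedged sketch of the expected behavior (again assuming the `jax.lax` re-export):

```python
from jax import lax

# BF16_BF16_F32 has no explicit output list, so its supported outputs should
# be the one-element tuple wrapping its accumulation type (np.float32).
p = lax.DotAlgorithmPreset.BF16_BF16_F32
assert p.supported_output_types == (p.accumulation_type,)

# DEFAULT leaves both unconstrained, so both properties return None.
d = lax.DotAlgorithmPreset.DEFAULT
assert d.accumulation_type is None and d.supported_output_types is None
```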
@@ -941,30 +948,39 @@ def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike,
     tf32 = ir.FloatTF32Type.get()
     match self:
       case (
-          DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
-          DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY
+          | DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
       ):
-        fp8_dtypes = [np.dtype(dtypes.float8_e4m3b11fnuz),
-                      np.dtype(dtypes.float8_e4m3fn),
-                      np.dtype(dtypes.float8_e4m3fnuz),
-                      np.dtype(dtypes.float8_e5m2),
-                      np.dtype(dtypes.float8_e5m2fnuz)]
+        fp8_dtypes = [
+            np.dtype(dtypes.float8_e4m3b11fnuz),
+            np.dtype(dtypes.float8_e4m3fn),
+            np.dtype(dtypes.float8_e4m3fnuz),
+            np.dtype(dtypes.float8_e5m2),
+            np.dtype(dtypes.float8_e5m2fnuz),
+        ]
         if dtypes.float8_e3m4 is not None:
           fp8_dtypes += [np.dtype(dtypes.float8_e3m4)]
         if dtypes.float8_e4m3 is not None:
           fp8_dtypes += [np.dtype(dtypes.float8_e4m3)]
         if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes:
           raise ValueError(
               f"The dot algorithm '{self}' requires both inputs to have float8 "
-              f"dtypes. Got {lhs_dtype} and {rhs_dtype} instead.")
+              f'dtypes. Got {lhs_dtype} and {rhs_dtype} instead.'
+          )
         lhs = mlir.dtype_to_ir_type(dtypes.dtype(lhs_dtype))
         rhs = mlir.dtype_to_ir_type(dtypes.dtype(rhs_dtype))
         acc = ir.F32Type.get()
         return hlo.DotAlgorithm.get(
-            lhs, rhs, acc, 1, 1, 1,
-            self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM)
+            lhs,
+            rhs,
+            acc,
+            1,
+            1,
+            1,
+            self == DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM,
+        )
       case DotAlgorithmPreset.F16_F16_F16:
         return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False)
       case DotAlgorithmPreset.F16_F16_F32:
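
The f8 presets validate that both operands are float8 before building the `hlo.DotAlgorithm` attribute. Below is a standalone sketch of that guard, using the `ml_dtypes` package that provides JAX's float8 types; the newer `float8_e3m4`/`float8_e4m3` variants are omitted since their availability is version-dependent, and the function name is local to this sketch:

```python
import numpy as np
import ml_dtypes  # provides the float8 dtypes JAX builds on

fp8_dtypes = [
    np.dtype(ml_dtypes.float8_e4m3b11fnuz),
    np.dtype(ml_dtypes.float8_e4m3fn),
    np.dtype(ml_dtypes.float8_e4m3fnuz),
    np.dtype(ml_dtypes.float8_e5m2),
    np.dtype(ml_dtypes.float8_e5m2fnuz),
]

def check_f8_inputs(lhs_dtype, rhs_dtype):
  # Mirrors the guard in _convert_to_hlo_attr: both operands must be float8.
  if np.dtype(lhs_dtype) not in fp8_dtypes or np.dtype(rhs_dtype) not in fp8_dtypes:
    raise ValueError(
        "An ANY_F8 dot algorithm requires both inputs to have float8 "
        f"dtypes. Got {lhs_dtype} and {rhs_dtype} instead."
    )

check_f8_inputs(ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e5m2)  # passes
# check_f8_inputs(np.float16, ml_dtypes.float8_e5m2)             # would raise
```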
@@ -3649,9 +3665,8 @@ def maybe_convert_dtype(input_dtype, target_dtype):
       return input_dtype
     if not isinstance(target_dtype, tuple):
       target_dtype = (target_dtype,)
-    if any(input_dtype == d for d in target_dtype):
-      return input_dtype
-    return target_dtype[0]
+    return input_dtype if input_dtype in target_dtype else target_dtype[0]
+
   if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
     lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type)
     rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type)
@@ -3662,10 +3677,15 @@ def maybe_convert_dtype(input_dtype, target_dtype):
     out_dtype = maybe_convert_dtype(out_dtype, np.float32)
     return lhs_dtype, rhs_dtype, out_dtype
   else:
+    if isinstance(algorithm, DotAlgorithmPreset):
+      supported_output_types = algorithm.supported_output_types
+    else:
+      supported_output_types = (algorithm.accumulation_type,)
+
     return (
         maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type),
         maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type),
-        maybe_convert_dtype(out_dtype, algorithm.accumulation_type),
+        maybe_convert_dtype(out_dtype, supported_output_types),
     )
Comments (0)