
Commit 9bb6366

dfm authored and Google-ML-Automation committed
Allow more output storage types for some dot algorithms.
As reported in jax-ml#24794, some dot products resulted in an unnecessary conversion of their output because the output storage type was forced to match the algorithm's accumulation type. This change makes the output storage type selection more flexible.

Fixes jax-ml#24794

PiperOrigin-RevId: 695694179
1 parent 837bccc commit 9bb6366
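A minimal reproduction of the reported pattern, adapted from the regression test added in this commit (shapes and values are illustrative; inspecting the lowering this way assumes a backend that supports the F16_F16_F32 algorithm, such as GPU):

import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

def fun(lhs, rhs):
  # An f16 x f16 dot that accumulates in f32 but asks for f16 output storage.
  return lax.dot(lhs, rhs, precision="F16_F16_F32",
                 preferred_element_type=np.float16)

lhs = jnp.ones((3, 4), dtype=np.float16)
rhs = jnp.ones((4, 3), dtype=np.float16)

# Before this change the lowered StableHLO contained a convert of the result
# back to f16; afterwards the dot_general can store its output as f16 directly.
print(jax.jit(fun).lower(lhs, rhs).as_text())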

File tree: 2 files changed, +65 -17 lines changed

jax/_src/lax/lax.py

Lines changed: 52 additions & 17 deletions
@@ -906,16 +906,25 @@ def rhs_precision_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
     return self.lhs_precision_type
 
   @property
-  def accumulation_type(self) -> DTypeLike | None:
+  def accumulation_type(self) -> DTypeLike | tuple[DTypeLike, ...] | None:
     match self:
       case (
           DotAlgorithmPreset.DEFAULT |
           DotAlgorithmPreset.ANY_F8_ANY_F8_ANY |
           DotAlgorithmPreset.ANY_F8_ANY_F8_ANY_FAST_ACCUM
       ):
         return None
+      case (
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32 |
+          DotAlgorithmPreset.ANY_F8_ANY_F8_F32_FAST_ACCUM
+      ):
+        return (np.float32, np.float16, dtypes.bfloat16, dtypes.float8_e4m3fn,
+                dtypes.float8_e5m2, dtypes.float8_e5m2fnuz,
+                dtypes.float8_e4m3fnuz, dtypes.float8_e4m3b11fnuz)
       case DotAlgorithmPreset.F16_F16_F16:
         return np.float16
+      case DotAlgorithmPreset.F16_F16_F32:
+        return (np.float32, np.float16)
       case DotAlgorithmPreset.BF16_BF16_BF16:
         return dtypes.bfloat16
       case DotAlgorithmPreset.F64_F64_F64:
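For the F16_F16_F32 and fp8 presets, the property now reports a tuple of acceptable output storage types, with the preferred accumulation type first. A quick way to inspect this (assuming the preset enum is reachable as jax.lax.DotAlgorithmPreset, as in recent JAX releases):

import jax

preset = jax.lax.DotAlgorithmPreset.F16_F16_F32
# With this change the property returns a tuple, (np.float32, np.float16):
# accumulation happens in f32, but an f16 result may be stored without an
# extra conversion.
print(preset.accumulation_type)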
@@ -3619,6 +3628,37 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
   return precision._convert_to_hlo_attr(lhs_dtype, rhs_dtype)
 
 
+def get_algorithm_compute_types(
+    algorithm: DotAlgorithm | DotAlgorithmPreset,
+    lhs_dtype: DTypeLike,
+    rhs_dtype: DTypeLike,
+    out_dtype: DTypeLike | None = None,
+) -> tuple[DTypeLike | None, DTypeLike | None, DTypeLike | None]:
+  def maybe_convert_dtype(input_dtype, target_dtype):
+    if target_dtype is None:
+      return input_dtype
+    if not isinstance(target_dtype, tuple):
+      target_dtype = (target_dtype,)
+    if any(input_dtype == d for d in target_dtype):
+      return input_dtype
+    return target_dtype[0]
+  if algorithm == DotAlgorithmPreset.BF16_BF16_F32:
+    lhs_dtype = maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type)
+    rhs_dtype = maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type)
+    if lhs_dtype == dtypes.bfloat16:
+      out_dtype = maybe_convert_dtype(out_dtype,
+                                      (np.float32, dtypes.bfloat16))
+    else:
+      out_dtype = maybe_convert_dtype(out_dtype, np.float32)
+    return lhs_dtype, rhs_dtype, out_dtype
+  else:
+    return (
+        maybe_convert_dtype(lhs_dtype, algorithm.lhs_precision_type),
+        maybe_convert_dtype(rhs_dtype, algorithm.rhs_precision_type),
+        maybe_convert_dtype(out_dtype, algorithm.accumulation_type),
+    )
+
+
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
                        out_type, platform: str = "default"):
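The nested maybe_convert_dtype above encodes the selection rule: keep a dtype if it is already one of the algorithm's allowed types, otherwise fall back to the first (preferred) allowed type. A standalone sketch of that rule (pick_storage_type is a hypothetical name, not part of JAX):

import numpy as np

def pick_storage_type(requested, allowed):
  # Keep the requested dtype if it is allowed; otherwise fall back to the
  # preferred (first) allowed type. Mirrors maybe_convert_dtype above.
  if allowed is None:
    return requested
  if not isinstance(allowed, tuple):
    allowed = (allowed,)
  if any(requested == d for d in allowed):
    return requested
  return allowed[0]

# F16_F16_F32 now lists f16 as an allowed storage type, so an f16 request is kept:
print(np.dtype(pick_storage_type(np.dtype("float16"), (np.float32, np.float16))))  # float16
# An f64 request is not allowed, so it falls back to the preferred f32:
print(np.dtype(pick_storage_type(np.dtype("float64"), (np.float32, np.float16))))  # float32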
@@ -3656,20 +3696,17 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     # If an explicit algorithm was specified, we always cast the input types to
     # the correct types.
     def maybe_convert_dtype(operand, operand_aval, target_dtype):
-      if target_dtype is None:
-        return operand, operand_aval.dtype
-      if not isinstance(target_dtype, tuple):
-        target_dtype = (target_dtype,)
-      if any(operand_aval.dtype == d for d in target_dtype):
-        return operand, operand_aval.dtype
-      aval = core.ShapedArray(operand_aval.shape, target_dtype[0])
-      return mlir.convert_hlo(ctx, operand, operand_aval, aval), target_dtype[0]
-
-    lhs, lhs_dtype = maybe_convert_dtype(lhs, lhs_aval, precision.lhs_precision_type)
-    rhs, rhs_dtype = maybe_convert_dtype(rhs, rhs_aval, precision.rhs_precision_type)
-    accumulation_type = precision.accumulation_type
-    if accumulation_type is not None:
-      accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_type)
+      if target_dtype is None or operand_aval.dtype == target_dtype:
+        return operand
+      aval = core.ShapedArray(operand_aval.shape, target_dtype)
+      return mlir.convert_hlo(ctx, operand, operand_aval, aval)
+
+    lhs_dtype, rhs_dtype, accumulation_dtype = get_algorithm_compute_types(
+        precision, lhs_dtype, rhs_dtype, aval_out.dtype)
+    lhs = maybe_convert_dtype(lhs, lhs_aval, lhs_dtype)
+    rhs = maybe_convert_dtype(rhs, rhs_aval, rhs_dtype)
+    if accumulation_dtype is not None:
+      accumulation_aval = core.ShapedArray(aval_out.shape, accumulation_dtype)
 
     if precision != DotAlgorithmPreset.DEFAULT:
       algorithm_kwarg = {
@@ -3690,15 +3727,13 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
                              core.ShapedArray(lhs_aval.shape, aval_out.dtype))
       rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
                              core.ShapedArray(rhs_aval.shape, aval_out.dtype))
-      lhs_dtype = rhs_dtype = aval_out.dtype
     else:  # cpu and gpu
       # Do not convert mixed fp8 types to output type.
       if not _is_fp8_mixed_precision_matmul(lhs_dtype, rhs_dtype):
         lhs = mlir.convert_hlo(ctx, lhs, lhs_aval,
                                core.ShapedArray(lhs_aval.shape, aval_out.dtype))
         rhs = mlir.convert_hlo(ctx, rhs, rhs_aval,
                                core.ShapedArray(rhs_aval.shape, aval_out.dtype))
-        lhs_dtype = rhs_dtype = aval_out.dtype
 
   result = hlo.dot_general(
       mlir.aval_to_ir_type(accumulation_aval),

tests/lax_test.py

Lines changed: 13 additions & 0 deletions
@@ -1146,6 +1146,19 @@ def fun(lhs, rhs):
     lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16)
     self.assertEqual(fun(lhs, rhs).dtype, np.float16)
 
+  def testDotAlgorithmAllowedOutputStorage(self):
+    # see https://github.com/jax-ml/jax/issues/24794
+    if not jtu.test_device_matches(["gpu"]):
+      self.skipTest("Only supported on GPU.")
+    def fun(lhs, rhs):
+      return lax.dot(lhs, rhs, precision="F16_F16_F32",
+                     preferred_element_type=np.float16)
+    lhs_shape = (3, 4)
+    rhs_shape = (4, 3)
+    rng = jtu.rand_default(self.rng())
+    lhs, rhs = rng(lhs_shape, np.float16), rng(rhs_shape, np.float16)
+    self.assertNotIn("convert", jax.jit(fun).lower(lhs, rhs).as_text())
+
   def testDotAlgorithmConfig(self):
     lhs_shape = (3, 4)
     rhs_shape = (4, 3)
