Skip to content

Commit 1f4d184

Browse files
dfmGoogle-ML-Automation
authored andcommitted
Temporarily allow bfloat16 dot algorithms on CPU.
Since XLA:CPU doesn't (yet!) support explicit algorithms for controlling the precision of dot products we have a check in JAX that fails when a non-trivial algorithm is specified on CPU. In order to support downstream use cases, this change allows some bfloat16 algorithms to pass through. XLA:CPU "emulates" these algorithms using `F32_F32_F32` with the appropriate casting, so that means that CPU numerics will be different than on other platforms with explicit algorithm support, but it is useful to be able to use these algorithms with the correct input and output casting without requiring platform dependent logic in user code. PiperOrigin-RevId: 703834889
1 parent 861115a commit 1f4d184

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

jax/_src/lax/lax.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3773,6 +3773,8 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
37733773
if platform == "cpu" and precision not in {
37743774
DotAlgorithmPreset.DEFAULT, DotAlgorithmPreset.F16_F16_F16,
37753775
DotAlgorithmPreset.F32_F32_F32, DotAlgorithmPreset.F64_F64_F64,
3776+
DotAlgorithmPreset.BF16_BF16_F32, DotAlgorithmPreset.BF16_BF16_F32_X3,
3777+
DotAlgorithmPreset.BF16_BF16_F32_X6,
37763778
}:
37773779
raise ValueError(
37783780
f"The precision '{precision}' is not supported by dot_general on CPU")

tests/lax_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,9 @@ def testDotAlgorithm(self, algorithm, dtype):
10821082
lax.DotAlgorithmPreset.F16_F16_F16,
10831083
lax.DotAlgorithmPreset.F32_F32_F32,
10841084
lax.DotAlgorithmPreset.F64_F64_F64,
1085+
lax.DotAlgorithmPreset.BF16_BF16_F32,
1086+
lax.DotAlgorithmPreset.BF16_BF16_F32_X3,
1087+
lax.DotAlgorithmPreset.BF16_BF16_F32_X6,
10851088
}:
10861089
raise SkipTest(
10871090
f"The dot algorithm '{algorithm}' is not supported on CPU.")

0 commit comments

Comments
 (0)