Skip to content

Commit 02b5d47

Browse files
Google-ML-Automationjax authors
authored andcommitted
Swap operands of dot if the LHS is fed by a parameter
PiperOrigin-RevId: 642090766
1 parent 9439f63 commit 02b5d47

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tests/shape_poly_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3327,6 +3327,10 @@ def test_harness(self, harness: PolyHarness):
33273327
if "random_gamma" in harness.group_name:
33283328
config_flags = {**config_flags, "jax_debug_key_reuse": False}
33293329

3330+
# TPU precision is a little lower since we swap the order of matmul operands.
3331+
if "cholesky" in harness.group_name and jtu.test_device_matches(["tpu"]):
3332+
harness.tol = 5e-5
3333+
33303334
with jtu.global_config_context(**config_flags):
33313335
harness.run_test(self)
33323336

0 commit comments

Comments
 (0)