Skip to content

Commit 48da48a

Browse files
committed
Increase matmul precision in tests to make them less flaky.
1 parent 70dbcdd commit 48da48a

File tree

5 files changed

+5
-0
lines changed

5 files changed

+5
-0
lines changed

tests/test_cross.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import numpy as np
1212
import math
1313
import einops
14+
jax.config.update("jax_default_matmul_precision", "highest")
1415

1516
from flash_attn_jax import flash_mha
1617
from .ref_mha import ref_mha

tests/test_flash.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import numpy as np
1313
import math
1414
import einops
15+
jax.config.update("jax_default_matmul_precision", "highest")
1516

1617
from flash_attn_jax import flash_mha
1718
from .ref_mha import ref_mha

tests/test_ring.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from functools import partial
1818
import einops
1919
import math
20+
jax.config.update("jax_default_matmul_precision", "highest")
2021

2122
from flash_attn_jax.ring_attention import ring_fwd, ring_bwd
2223
from .ref_mha import ref_fwd, ref_bwd

tests/test_sharding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from jax.sharding import PartitionSpec as P
1515
from jax.tree_util import tree_map
1616
jax.config.update("jax_traceback_filtering", "off")
17+
jax.config.update("jax_default_matmul_precision", "highest")
1718

1819
from flash_attn_jax import flash_mha
1920
from .ref_mha import ref_mha

tests/test_varlen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import numpy as np
1313
import math
1414
import einops
15+
jax.config.update("jax_default_matmul_precision", "highest")
1516

1617
from flash_attn_jax import flash_mha
1718
from flash_attn_jax.varlen import flash_mha_varlen

0 commit comments

Comments
 (0)