File tree Expand file tree Collapse file tree 5 files changed +5
-0
lines changed Expand file tree Collapse file tree 5 files changed +5
-0
lines changed Original file line number Diff line number Diff line change 11
11
import numpy as np
12
12
import math
13
13
import einops
14
+ jax .config .update ("jax_default_matmul_precision" , "highest" )
14
15
15
16
from flash_attn_jax import flash_mha
16
17
from .ref_mha import ref_mha
Original file line number Diff line number Diff line change 12
12
import numpy as np
13
13
import math
14
14
import einops
15
+ jax .config .update ("jax_default_matmul_precision" , "highest" )
15
16
16
17
from flash_attn_jax import flash_mha
17
18
from .ref_mha import ref_mha
Original file line number Diff line number Diff line change 17
17
from functools import partial
18
18
import einops
19
19
import math
20
+ jax .config .update ("jax_default_matmul_precision" , "highest" )
20
21
21
22
from flash_attn_jax .ring_attention import ring_fwd , ring_bwd
22
23
from .ref_mha import ref_fwd , ref_bwd
Original file line number Diff line number Diff line change 14
14
from jax .sharding import PartitionSpec as P
15
15
from jax .tree_util import tree_map
16
16
jax .config .update ("jax_traceback_filtering" , "off" )
17
+ jax .config .update ("jax_default_matmul_precision" , "highest" )
17
18
18
19
from flash_attn_jax import flash_mha
19
20
from .ref_mha import ref_mha
Original file line number Diff line number Diff line change 12
12
import numpy as np
13
13
import math
14
14
import einops
15
+ jax .config .update ("jax_default_matmul_precision" , "highest" )
15
16
16
17
from flash_attn_jax import flash_mha
17
18
from flash_attn_jax .varlen import flash_mha_varlen
You can’t perform that action at this time.
0 commit comments