Skip to content

Commit 46c748b

Browse files
Merge pull request jax-ml#25055 from dfm:multi-dot
PiperOrigin-RevId: 702039013
2 parents 385328b + 236d4c6 commit 46c748b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/numpy/linalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2119,7 +2119,7 @@ def multi_dot(arrays: Sequence[ArrayLike], *, precision: PrecisionLike = None) -
21192119
if arrs[-1].ndim == 1:
21202120
einsum_axes[-1] = einsum_axes[-1][:1]
21212121
return jnp.einsum(*itertools.chain(*zip(arrs, einsum_axes)), # type: ignore[call-overload]
2122-
optimize='optimal', precision=precision)
2122+
optimize='auto', precision=precision)
21232123

21242124

21252125
@export

0 commit comments

Comments
 (0)