Skip to content

Commit 2e0474a

Browse files
Merge pull request jax-ml#25191 from houeland:patch-1
PiperOrigin-RevId: 702007643
2 parents b1423a3 + cd578d9 commit 2e0474a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9076,7 +9076,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
90769076
90779077
Returns:
90789078
array containing the matrix product of the inputs. Shape is ``a.shape[:-1]``
9079-
if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading
9079+
if ``b.ndim == 1``, otherwise the shape is ``(..., K, M)``, where leading
90809080
dimensions of ``a`` and ``b`` are broadcast together.
90819081
90829082
See Also:

0 commit comments

Comments
 (0)