Skip to content

Commit cd578d9

Browse files
authored
Fix jnp.matmul return shape documentation
If e.g. a.shape = (2, 3, 5, 7, 11) and b.shape = (2, 3, 5, 11, 13), then the output shape = (2, 3, 5, 7, 13)
1 parent 47858c4 commit cd578d9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9076,7 +9076,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *,
90769076
90779077
Returns:
90789078
array containing the matrix product of the inputs. Shape is ``a.shape[:-1]``
9079-
if ``b.ndim == 1``, otherwise the shape is ``(..., M)``, where leading
9079+
if ``b.ndim == 1``, otherwise the shape is ``(..., K, M)``, where leading
90809080
dimensions of ``a`` and ``b`` are broadcast together.
90819081
90829082
See Also:

0 commit comments

Comments
 (0)