Skip to content

Commit feb75bc

Browse files
committed
docs: move notes below the fold and highlight RFC 2119 keywords
Ref: #397
1 parent 3267c18 commit feb75bc

File tree

1 file changed

+49
-49
lines changed

1 file changed

+49
-49
lines changed

src/array_api_stubs/_draft/linear_algebra_functions.py

Lines changed: 49 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,47 +8,53 @@ def matmul(x1: array, x2: array, /) -> array:
88
"""
99
Computes the matrix product.
1010
11-
.. note::
12-
The ``matmul`` function must implement the same semantics as the built-in ``@`` operator (see `PEP 465 <https://www.python.org/dev/peps/pep-0465>`_).
13-
1411
Parameters
1512
----------
1613
x1: array
17-
first input array. Should have a numeric data type. Must have at least one dimension. If ``x1`` is one-dimensional having shape ``(M,)`` and ``x2`` has more than one dimension, ``x1`` must be promoted to a two-dimensional array by prepending ``1`` to its dimensions (i.e., must have shape ``(1, M)``). After matrix multiplication, the prepended dimensions in the returned array must be removed. If ``x1`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x1)[:-2]`` must be compatible with ``shape(x2)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`). If ``x1`` has shape ``(..., M, K)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
18-
x2: array
19-
second input array. Should have a numeric data type. Must have at least one dimension. If ``x2`` is one-dimensional having shape ``(N,)`` and ``x1`` has more than one dimension, ``x2`` must be promoted to a two-dimensional array by appending ``1`` to its dimensions (i.e., must have shape ``(N, 1)``). After matrix multiplication, the appended dimensions in the returned array must be removed. If ``x2`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x2)[:-2]`` must be compatible with ``shape(x1)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`). If ``x2`` has shape ``(..., K, N)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
14+
first input array. **Should** have a numeric data type. **Must** have at least one dimension.
2015
16+
- If ``x1`` is a one-dimensional array having shape ``(M,)`` and ``x2`` has more than one dimension, ``x1`` **must** be promoted to a two-dimensional array by prepending ``1`` to its dimensions (i.e., **must** have shape ``(1, M)``). After matrix multiplication, the prepended dimensions in the returned array **must** be removed.
17+
- If ``x1`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x1)[:-2]`` **must** be compatible with ``shape(x2)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`).
18+
- If ``x1`` has shape ``(..., M, K)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
19+
20+
x2: array
21+
second input array. **Should** have a numeric data type. **Must** have at least one dimension.
2122
22-
.. note::
23-
If either ``x1`` or ``x2`` has a complex floating-point data type, neither argument must be complex-conjugated or transposed. If conjugation and/or transposition is desired, these operations should be explicitly performed prior to computing the matrix product.
23+
- If ``x2`` is one-dimensional array having shape ``(N,)`` and ``x1`` has more than one dimension, ``x2`` **must** be promoted to a two-dimensional array by appending ``1`` to its dimensions (i.e., **must** have shape ``(N, 1)``). After matrix multiplication, the appended dimensions in the returned array **must** be removed.
24+
- If ``x2`` has more than one dimension (including after vector-to-matrix promotion), ``shape(x2)[:-2]`` **must** be compatible with ``shape(x1)[:-2]`` (after vector-to-matrix promotion) (see :ref:`broadcasting`).
25+
- If ``x2`` has shape ``(..., K, N)``, the innermost two dimensions form matrices on which to perform matrix multiplication.
2426
2527
Returns
2628
-------
2729
out: array
2830
- if both ``x1`` and ``x2`` are one-dimensional arrays having shape ``(N,)``, a zero-dimensional array containing the inner product as its only element.
2931
- if ``x1`` is a two-dimensional array having shape ``(M, K)`` and ``x2`` is a two-dimensional array having shape ``(K, N)``, a two-dimensional array containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ and having shape ``(M, N)``.
30-
- if ``x1`` is a one-dimensional array having shape ``(K,)`` and ``x2`` is an array having shape ``(..., K, N)``, an array having shape ``(..., N)`` (i.e., prepended dimensions during vector-to-matrix promotion must be removed) and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
31-
- if ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a one-dimensional array having shape ``(K,)``, an array having shape ``(..., M)`` (i.e., appended dimensions during vector-to-matrix promotion must be removed) and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
32+
- if ``x1`` is a one-dimensional array having shape ``(K,)`` and ``x2`` is an array having shape ``(..., K, N)``, an array having shape ``(..., N)`` (i.e., prepended dimensions during vector-to-matrix promotion **must** be removed) and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
33+
- if ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a one-dimensional array having shape ``(K,)``, an array having shape ``(..., M)`` (i.e., appended dimensions during vector-to-matrix promotion **must** be removed) and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_.
3234
- if ``x1`` is a two-dimensional array having shape ``(M, K)`` and ``x2`` is an array having shape ``(..., K, N)``, an array having shape ``(..., M, N)`` and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
3335
- if ``x1`` is an array having shape ``(..., M, K)`` and ``x2`` is a two-dimensional array having shape ``(K, N)``, an array having shape ``(..., M, N)`` and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
3436
- if either ``x1`` or ``x2`` has more than two dimensions, an array having a shape determined by :ref:`broadcasting` ``shape(x1)[:-2]`` against ``shape(x2)[:-2]`` and containing the `conventional matrix product <https://en.wikipedia.org/wiki/Matrix_multiplication>`_ for each stacked matrix.
3537
36-
The returned array must have a data type determined by :ref:`type-promotion`.
37-
38-
Notes
39-
-----
40-
41-
.. versionchanged:: 2022.12
42-
Added complex data type support.
38+
The returned array **must** have a data type determined by :ref:`type-promotion`.
4339
44-
**Raises**
40+
Raises
41+
------
4542
4643
- if either ``x1`` or ``x2`` is a zero-dimensional array.
4744
- if ``x1`` is a one-dimensional array having shape ``(K,)``, ``x2`` is a one-dimensional array having shape ``(L,)``, and ``K != L``.
4845
- if ``x1`` is a one-dimensional array having shape ``(K,)``, ``x2`` is an array having shape ``(..., L, N)``, and ``K != L``.
4946
- if ``x1`` is an array having shape ``(..., M, K)``, ``x2`` is a one-dimensional array having shape ``(L,)``, and ``K != L``.
5047
- if ``x1`` is an array having shape ``(..., M, K)``, ``x2`` is an array having shape ``(..., L, N)``, and ``K != L``.
5148
49+
Notes
50+
-----
51+
52+
- The ``matmul`` function **must** implement the same semantics as the built-in ``@`` operator (see `PEP 465 <https://www.python.org/dev/peps/pep-0465>`_).
53+
54+
- If either ``x1`` or ``x2`` has a complex floating-point data type, the function **must not** complex-conjugate or tranpose either argument. If conjugation and/or transposition is desired, a user can explicitly perform these operations prior to computing the matrix product.
55+
56+
.. versionchanged:: 2022.12
57+
Added complex data type support.
5258
"""
5359

5460

@@ -64,7 +70,7 @@ def matrix_transpose(x: array, /) -> array:
6470
Returns
6571
-------
6672
out: array
67-
an array containing the transpose for each matrix and having shape ``(..., N, M)``. The returned array must have the same data type as ``x``.
73+
an array containing the transpose for each matrix. The returned array **must** have shape ``(..., N, M)``. The returned array **must** have the same data type as ``x``.
6874
"""
6975

7076

@@ -78,42 +84,37 @@ def tensordot(
7884
"""
7985
Returns a tensor contraction of ``x1`` and ``x2`` over specific axes.
8086
81-
.. note::
82-
The ``tensordot`` function corresponds to the generalized matrix product.
83-
8487
Parameters
8588
----------
8689
x1: array
87-
first input array. Should have a numeric data type.
90+
first input array. **Should** have a numeric data type.
8891
x2: array
89-
second input array. Should have a numeric data type. Corresponding contracted axes of ``x1`` and ``x2`` must be equal.
90-
91-
.. note::
92-
Contracted axes (dimensions) must not be broadcasted.
92+
second input array. **Should** have a numeric data type. Corresponding contracted axes of ``x1`` and ``x2`` **must** be equal.
9393
9494
axes: Union[int, Tuple[Sequence[int], Sequence[int]]]
95-
number of axes (dimensions) to contract or explicit sequences of axis (dimension) indices for ``x1`` and ``x2``, respectively.
96-
97-
If ``axes`` is an ``int`` equal to ``N``, then contraction must be performed over the last ``N`` axes of ``x1`` and the first ``N`` axes of ``x2`` in order. The size of each corresponding axis (dimension) must match. Must be nonnegative.
95+
number of axes to contract or explicit sequences of axis indices for ``x1`` and ``x2``, respectively.
9896
99-
- If ``N`` equals ``0``, the result is the tensor (outer) product.
100-
- If ``N`` equals ``1``, the result is the tensor dot product.
101-
- If ``N`` equals ``2``, the result is the tensor double contraction (default).
97+
If ``axes`` is an ``int`` equal to ``N``, then contraction **must** be performed over the last ``N`` axes of ``x1`` and the first ``N`` axes of ``x2`` in order. The size of each corresponding axis **must** match. An integer ``axes`` value **must** be nonnegative.
10298
103-
If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence must apply to ``x1`` and the second sequence to ``x2``. Both sequences must have the same length. Each axis (dimension) ``x1_axes[i]`` for ``x1`` must have the same size as the respective axis (dimension) ``x2_axes[i]`` for ``x2``. Each index referred to in a sequence must be unique. If ``x1`` has rank (i.e, number of dimensions) ``N``, a valid ``x1`` axis must reside on the half-open interval ``[-N, N)``. If ``x2`` has rank ``M``, a valid ``x2`` axis must reside on the half-open interval ``[-M, M)``.
99+
- If ``N`` equals ``0``, the result **must** be the tensor (outer) product.
100+
- If ``N`` equals ``1``, the result **must** be the tensor dot product.
101+
- If ``N`` equals ``2``, the result **must** be the tensor double contraction (default).
104102
103+
If ``axes`` is a tuple of two sequences ``(x1_axes, x2_axes)``, the first sequence **must** apply to ``x1`` and the second sequence **must** apply to ``x2``. Both sequences **must** have the same length. Each axis ``x1_axes[i]`` for ``x1`` **must** have the same size as the respective axis ``x2_axes[i]`` for ``x2``. Each index referred to in a sequence **must** be unique. A valid axis **must** be an integer on the interval ``[-S, S)``, where ``S`` is the number of axes in respective array. Hence, if ``x1`` has ``N`` axes, a valid ``x1`` axes **must** be an integer on the interval ``[-N, N)``. If ``x2`` has ``M`` axes, a valid ``x2`` axes **must** be an integer on the interval ``[-M, M)``. If an axis is specified as a negative integer, the function **must** determine the axis along which to perform the operation by counting backward from the last axis (where ``-1`` refers to the last axis). If provided an invalid axis, the function **must** raise an exception.
105104
106-
.. note::
107-
If either ``x1`` or ``x2`` has a complex floating-point data type, neither argument must be complex-conjugated or transposed. If conjugation and/or transposition is desired, these operations should be explicitly performed prior to computing the generalized matrix product.
108105
109106
Returns
110107
-------
111108
out: array
112-
an array containing the tensor contraction whose shape consists of the non-contracted axes (dimensions) of the first array ``x1``, followed by the non-contracted axes (dimensions) of the second array ``x2``. The returned array must have a data type determined by :ref:`type-promotion`.
109+
an array containing the tensor contraction. The returned array **must** have a shape which consists of the non-contracted axes of the first array ``x1``, followed by the non-contracted axes of the second array ``x2``. The returned array **must** have a data type determined by :ref:`type-promotion`.
113110
114111
Notes
115112
-----
116113
114+
- The ``tensordot`` function corresponds to the generalized matrix product.
115+
- Contracted axes **must** not be broadcasted.
116+
- If either ``x1`` or ``x2`` has a complex floating-point data type, the function **must not** complex-conjugate or transpose either argument. If conjugation and/or transposition is desired, a user can explicitly perform these operations prior to computing the generalized matrix product.
117+
117118
.. versionchanged:: 2022.12
118119
Added complex data type support.
119120
@@ -131,32 +132,31 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
131132
.. math::
132133
\mathbf{a} \cdot \mathbf{b} = \sum_{i=0}^{n-1} \overline{a_i}b_i
133134
134-
over the dimension specified by ``axis`` and where :math:`n` is the dimension size and :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i` is complex and the identity if :math:`a_i` is real-valued.
135+
over the axis specified by ``axis`` and where :math:`n` is the axis size and :math:`\overline{a_i}` denotes the complex conjugate if :math:`a_i` is complex and the identity if :math:`a_i` is real-valued.
135136
136137
Parameters
137138
----------
138139
x1: array
139-
first input array. Should have a floating-point data type.
140+
first input array. **Should** have a floating-point data type.
140141
x2: array
141-
second input array. Must be compatible with ``x1`` for all non-contracted axes (see :ref:`broadcasting`). The size of the axis over which to compute the dot product must be the same size as the respective axis in ``x1``. Should have a floating-point data type.
142-
143-
.. note::
144-
The contracted axis (dimension) must not be broadcasted.
145-
142+
second input array. **Must** be compatible with ``x1`` for all non-contracted axes (see :ref:`broadcasting`). The size of the axis over which to compute the dot product **must** be the same size as the respective axis in ``x1``. **Should** have a floating-point data type.
146143
axis: int
147-
the axis (dimension) of ``x1`` and ``x2`` containing the vectors for which to compute the dot product. Should be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function must determine the axis along which to compute the dot product by counting backward from the last dimension (where ``-1`` refers to the last dimension). By default, the function must compute the dot product over the last axis. Default: ``-1``.
144+
axis of ``x1`` and ``x2`` containing the vectors for which to compute the dot product. **Should** be an integer on the interval ``[-N, -1]``, where ``N`` is ``min(x1.ndim, x2.ndim)``. The function **must** determine the axis along which to perform the operation by counting backward from the last axis (where ``-1`` refers to the last axis). By default, the function **must** compute the dot product over the last axis. Default: ``-1``.
148145
149146
Returns
150147
-------
151148
out: array
152-
if ``x1`` and ``x2`` are both one-dimensional arrays, a zero-dimensional containing the dot product; otherwise, a non-zero-dimensional array containing the dot products and having rank ``N-1``, where ``N`` is the rank (number of dimensions) of the shape determined according to :ref:`broadcasting` along the non-contracted axes. The returned array must have a data type determined by :ref:`type-promotion`.
149+
if ``x1`` and ``x2`` are both one-dimensional arrays, a zero-dimensional containing the dot product; otherwise, a non-zero-dimensional array containing the dot products and having rank ``N-1``, where ``N`` is number of axes in the shape determined according to :ref:`broadcasting` along the non-contracted axes. The returned array **must** have a data type determined by :ref:`type-promotion`.
150+
151+
Raises
152+
------
153+
154+
- if the size of the axis over which to compute the dot product is not the same (before broadcasting) for both ``x1`` and ``x2``.
153155
154156
Notes
155157
-----
156158
157-
**Raises**
158-
159-
- if the size of the axis over which to compute the dot product is not the same (before broadcasting) for both ``x1`` and ``x2``.
159+
- The contracted axis **must** not be broadcasted.
160160
161161
.. versionchanged:: 2022.12
162162
Added complex data type support.

0 commit comments

Comments
 (0)