Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@
:toctree: generated

atleast_nd
expand_dims
kron
```
44 changes: 29 additions & 15 deletions pixi.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pylint = "*"
# import dependencies for mypy:
array-api-strict = "*"
numpy = "*"
pytest = "*"

[tool.pixi.feature.lint.tasks]
pre-commit = { cmd = "pre-commit install && pre-commit run -v --all-files --show-diff-on-failure" }
Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from ._funcs import atleast_nd
from ._funcs import atleast_nd, expand_dims, kron

__version__ = "0.1.2.dev0"

__all__ = ["__version__", "atleast_nd"]
__all__ = ["__version__", "atleast_nd", "expand_dims", "kron"]
197 changes: 195 additions & 2 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
if TYPE_CHECKING:
from ._typing import Array, ModuleType

__all__ = ["atleast_nd"]
__all__ = ["atleast_nd", "expand_dims", "kron"]


def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.

Expand Down Expand Up @@ -46,3 +46,196 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
x = xp.expand_dims(x, axis=0)
x = atleast_nd(x, ndim=ndim, xp=xp)
return x


def expand_dims(
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
) -> Array:
"""
Expand the shape of an array.

Insert (a) new axis/axes that will appear at the position(s) specified by
`axis` in the expanded array shape.

This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.

Parameters
----------
a : array
axis : int or tuple of ints, optional
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
If multiple positions are provided, they should be unique (note that a position
given by a positive index could also be referred to by a negative index -
that will also result in an error).
Default: ``(0,)``.
xp : array_namespace
The standard-compatible namespace for `a`.

Returns
-------
res : array
`a` with an expanded shape.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([1, 2])
>>> x.shape
(2,)

The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``:

>>> y = xpx.expand_dims(x, axis=0, xp=xp)
>>> y
Array([[1, 2]], dtype=array_api_strict.int64)
>>> y.shape
(1, 2)

The following is equivalent to ``x[:, xp.newaxis]``:

>>> y = xpx.expand_dims(x, axis=1, xp=xp)
>>> y
Array([[1],
[2]], dtype=array_api_strict.int64)
>>> y.shape
(2, 1)

``axis`` may also be a tuple:

>>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
>>> y
Array([[[1, 2]]], dtype=array_api_strict.int64)

>>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
>>> y
Array([[[1],
[2]]], dtype=array_api_strict.int64)

"""
if not isinstance(axis, tuple):
axis = (axis,)
ndim = a.ndim + len(axis)
if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
err_msg = (
f"a provided axis position is out of bounds for array of dimension {a.ndim}"
)
raise IndexError(err_msg)
axis = tuple(dim % ndim for dim in axis)
if len(set(axis)) != len(axis):
err_msg = "Duplicate dimensions specified in `axis`."
raise ValueError(err_msg)
if len(set(axis)) != len(axis):
err_msg = "Duplicate dimensions specified in `axis`."
raise ValueError(err_msg)
for i in sorted(axis):
a = xp.expand_dims(a, axis=i)
return a


def kron(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
"""
Kronecker product of two arrays.

Computes the Kronecker product, a composite array made of blocks of the
second array scaled by the first.

Equivalent to ``numpy.kron`` for NumPy arrays.

Parameters
----------
a, b : array
xp : array_namespace
The standard-compatible namespace for `a` and `b`.

Returns
-------
res : array
The Kronecker product of `a` and `b`.

Notes
-----
The function assumes that the number of dimensions of `a` and `b`
are the same, if necessary prepending the smallest with ones.
If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
The elements are products of elements from `a` and `b`, organized
explicitly by::

kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]

where::

kt = it * st + jt, t = 0,...,N

In the common 2-D case (N=1), the block structure can be visualized::

[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
[ ... ... ],
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]


Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
Array([ 5, 6, 7, 50, 60, 70, 500,
600, 700], dtype=array_api_strict.int64)

>>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
Array([ 5, 50, 500, 6, 60, 600, 7,
70, 700], dtype=array_api_strict.int64)

>>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
Array([[1., 1., 0., 0.],
[1., 1., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 1., 1.]], dtype=array_api_strict.float64)


>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
>>> c = xpx.kron(a, b, xp=xp)
>>> c.shape
(2, 10, 6, 20)
>>> I = (1, 3, 0, 2)
>>> J = (0, 2, 1)
>>> J1 = (0,) + J # extend to ndim=4
>>> S1 = (1,) + b.shape
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
>>> c[K] == a[I]*b[J]
Array(True, dtype=array_api_strict.bool)

"""

b = xp.asarray(b)
singletons = (1,) * (b.ndim - a.ndim)
a = xp.broadcast_to(xp.asarray(a), singletons + a.shape)

nd_b, nd_a = b.ndim, a.ndim
nd_max = max(nd_b, nd_a)
if nd_a == 0 or nd_b == 0:
return xp.multiply(a, b)

a_shape = a.shape
b_shape = b.shape

# Equalise the shapes by prepending smaller one with 1s
a_shape = (1,) * max(0, nd_b - nd_a) + a_shape
b_shape = (1,) * max(0, nd_a - nd_b) + b_shape

# Insert empty dimensions
a_arr = expand_dims(a, axis=tuple(range(nd_b - nd_a)), xp=xp)
b_arr = expand_dims(b, axis=tuple(range(nd_a - nd_b)), xp=xp)

# Compute the product
a_arr = expand_dims(a_arr, axis=tuple(range(1, nd_max * 2, 2)), xp=xp)
b_arr = expand_dims(b_arr, axis=tuple(range(0, nd_max * 2, 2)), xp=xp)
result = xp.multiply(a_arr, b_arr)

# Reshape back and return
a_shape = xp.asarray(a_shape)
b_shape = xp.asarray(b_shape)
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))
Loading