Skip to content

Commit 6b33570

Browse files
committed
add kron delegate version.
1 parent 0d77147 commit 6b33570

File tree

3 files changed

+102
-81
lines changed

3 files changed

+102
-81
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
expand_dims,
99
isclose,
1010
isin,
11+
kron,
1112
nan_to_num,
1213
one_hot,
1314
pad,
@@ -20,13 +21,11 @@
2021
apply_where,
2122
broadcast_shapes,
2223
default_dtype,
23-
kron,
2424
nunique,
2525
)
2626
from ._lib._lazy import lazy_apply
2727

2828
__version__ = "0.9.1.dev0"
29-
3029
# pylint: disable=duplicate-code
3130
__all__ = [
3231
"__version__",

src/array_api_extra/_delegation.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
"create_diagonal",
2525
"expand_dims",
2626
"isclose",
27+
"kron",
2728
"nan_to_num",
2829
"one_hot",
2930
"pad",
@@ -416,6 +417,101 @@ def isclose(
416417
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
417418

418419

420+
def kron(
421+
a: Array | complex,
422+
b: Array | complex,
423+
/,
424+
*,
425+
xp: ModuleType | None = None,
426+
) -> Array:
427+
"""
428+
Kronecker product of two arrays.
429+
430+
Computes the Kronecker product, a composite array made of blocks of the
431+
second array scaled by the first.
432+
433+
Equivalent to ``numpy.kron`` for NumPy arrays.
434+
435+
Parameters
436+
----------
437+
a, b : Array | int | float | complex
438+
Input arrays or scalars. At least one must be an array.
439+
xp : array_namespace, optional
440+
The standard-compatible namespace for `a` and `b`. Default: infer.
441+
442+
Returns
443+
-------
444+
array
445+
The Kronecker product of `a` and `b`.
446+
447+
Notes
448+
-----
449+
The function assumes that the number of dimensions of `a` and `b`
450+
are the same, if necessary prepending the smallest with ones.
451+
If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
452+
the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
453+
The elements are products of elements from `a` and `b`, organized
454+
explicitly by::
455+
456+
kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]
457+
458+
where::
459+
460+
kt = it * st + jt, t = 0,...,N
461+
462+
In the common 2-D case (N=1), the block structure can be visualized::
463+
464+
[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
465+
[ ... ... ],
466+
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]
467+
468+
Examples
469+
--------
470+
>>> import array_api_strict as xp
471+
>>> import array_api_extra as xpx
472+
>>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
473+
Array([ 5, 6, 7, 50, 60, 70, 500,
474+
600, 700], dtype=array_api_strict.int64)
475+
476+
>>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
477+
Array([ 5, 50, 500, 6, 60, 600, 7,
478+
70, 700], dtype=array_api_strict.int64)
479+
480+
>>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
481+
Array([[1., 1., 0., 0.],
482+
[1., 1., 0., 0.],
483+
[0., 0., 1., 1.],
484+
[0., 0., 1., 1.]], dtype=array_api_strict.float64)
485+
486+
>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
487+
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
488+
>>> c = xpx.kron(a, b, xp=xp)
489+
>>> c.shape
490+
(2, 10, 6, 20)
491+
>>> I = (1, 3, 0, 2)
492+
>>> J = (0, 2, 1)
493+
>>> J1 = (0,) + J # extend to ndim=4
494+
>>> S1 = (1,) + b.shape
495+
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
496+
>>> c[K] == a[I]*b[J]
497+
Array(True, dtype=array_api_strict.bool)
498+
"""
499+
if xp is None:
500+
xp = array_namespace(a, b)
501+
502+
a, b = asarrays(a, b, xp=xp)
503+
504+
if (
505+
is_cupy_namespace(xp)
506+
or is_jax_namespace(xp)
507+
or is_numpy_namespace(xp)
508+
or is_torch_namespace(xp)
509+
):
510+
return xp.kron(a, b)
511+
512+
return _funcs.kron(a, b, xp=xp)
513+
514+
419515
def nan_to_num(
420516
x: Array | float | complex,
421517
/,

src/array_api_extra/_lib/_funcs.py

Lines changed: 5 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -407,87 +407,13 @@ def isclose(
407407

408408

409409
def kron(
410-
a: Array | complex,
411-
b: Array | complex,
410+
a: Array,
411+
b: Array,
412412
/,
413413
*,
414-
xp: ModuleType | None = None,
415-
) -> Array:
416-
"""
417-
Kronecker product of two arrays.
418-
419-
Computes the Kronecker product, a composite array made of blocks of the
420-
second array scaled by the first.
421-
422-
Equivalent to ``numpy.kron`` for NumPy arrays.
423-
424-
Parameters
425-
----------
426-
a, b : Array | int | float | complex
427-
Input arrays or scalars. At least one must be an array.
428-
xp : array_namespace, optional
429-
The standard-compatible namespace for `a` and `b`. Default: infer.
430-
431-
Returns
432-
-------
433-
array
434-
The Kronecker product of `a` and `b`.
435-
436-
Notes
437-
-----
438-
The function assumes that the number of dimensions of `a` and `b`
439-
are the same, if necessary prepending the smallest with ones.
440-
If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
441-
the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
442-
The elements are products of elements from `a` and `b`, organized
443-
explicitly by::
444-
445-
kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]
446-
447-
where::
448-
449-
kt = it * st + jt, t = 0,...,N
450-
451-
In the common 2-D case (N=1), the block structure can be visualized::
452-
453-
[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
454-
[ ... ... ],
455-
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]
456-
457-
Examples
458-
--------
459-
>>> import array_api_strict as xp
460-
>>> import array_api_extra as xpx
461-
>>> xpx.kron(xp.asarray([1, 10, 100]), xp.asarray([5, 6, 7]), xp=xp)
462-
Array([ 5, 6, 7, 50, 60, 70, 500,
463-
600, 700], dtype=array_api_strict.int64)
464-
465-
>>> xpx.kron(xp.asarray([5, 6, 7]), xp.asarray([1, 10, 100]), xp=xp)
466-
Array([ 5, 50, 500, 6, 60, 600, 7,
467-
70, 700], dtype=array_api_strict.int64)
468-
469-
>>> xpx.kron(xp.eye(2), xp.ones((2, 2)), xp=xp)
470-
Array([[1., 1., 0., 0.],
471-
[1., 1., 0., 0.],
472-
[0., 0., 1., 1.],
473-
[0., 0., 1., 1.]], dtype=array_api_strict.float64)
474-
475-
>>> a = xp.reshape(xp.arange(100), (2, 5, 2, 5))
476-
>>> b = xp.reshape(xp.arange(24), (2, 3, 4))
477-
>>> c = xpx.kron(a, b, xp=xp)
478-
>>> c.shape
479-
(2, 10, 6, 20)
480-
>>> I = (1, 3, 0, 2)
481-
>>> J = (0, 2, 1)
482-
>>> J1 = (0,) + J # extend to ndim=4
483-
>>> S1 = (1,) + b.shape
484-
>>> K = tuple(xp.asarray(I) * xp.asarray(S1) + xp.asarray(J1))
485-
>>> c[K] == a[I]*b[J]
486-
Array(True, dtype=array_api_strict.bool)
487-
"""
488-
if xp is None:
489-
xp = array_namespace(a, b)
490-
a, b = asarrays(a, b, xp=xp)
414+
xp: ModuleType,
415+
) -> Array: # numpydoc ignore=PR01,RT01
416+
"""See docstring in array_api_extra._delegation."""
491417

492418
singletons = (1,) * (b.ndim - a.ndim)
493419
a = cast(Array, xp.broadcast_to(a, singletons + a.shape))

0 commit comments

Comments
 (0)