Skip to content

Commit e9077f5

Browse files
lucascolleymdhaber
andcommitted
ENH: add kron and expand_dims
Co-authored-by: Matt Haberland <[email protected]>
1 parent 6e596d9 commit e9077f5

File tree

1 file changed

+173
-0
lines changed

1 file changed

+173
-0
lines changed

src/array_api_extra/_funcs.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,176 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
4646
x = xp.expand_dims(x, axis=0)
4747
x = atleast_nd(x, ndim=ndim, xp=xp)
4848
return x
49+
50+
51+
def expand_dims(a: Array, *, axis: tuple[int] = (0,), xp: ModuleType):
52+
"""
53+
Expand the shape of an array.
54+
55+
Insert a new axis that will appear at the `axis` position in the expanded
56+
array shape.
57+
58+
This is ``xp.expand_dims`` for ``axis`` an int *or a tuple of ints*.
59+
Equivalent to ``numpy.expand_dims`` for NumPy arrays.
60+
61+
Parameters
62+
----------
63+
a : array
64+
axis : int or tuple of ints
65+
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
66+
xp : array_namespace
67+
The standard-compatible namespace for `a`.
68+
69+
Returns
70+
-------
71+
res : array
72+
`a` with an expanded shape.
73+
74+
Examples
75+
--------
76+
# >>> import numpy as np
77+
# >>> x = np.array([1, 2])
78+
# >>> x.shape
79+
# (2,)
80+
81+
# The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``:
82+
83+
# >>> y = np.expand_dims(x, axis=0)
84+
# >>> y
85+
# array([[1, 2]])
86+
# >>> y.shape
87+
# (1, 2)
88+
89+
# The following is equivalent to ``x[:, np.newaxis]``:
90+
91+
# >>> y = np.expand_dims(x, axis=1)
92+
# >>> y
93+
# array([[1],
94+
# [2]])
95+
# >>> y.shape
96+
# (2, 1)
97+
98+
# ``axis`` may also be a tuple:
99+
100+
# >>> y = np.expand_dims(x, axis=(0, 1))
101+
# >>> y
102+
# array([[[1, 2]]])
103+
104+
# >>> y = np.expand_dims(x, axis=(2, 0))
105+
# >>> y
106+
# array([[[1],
107+
# [2]]])
108+
109+
# Note that some examples may use ``None`` instead of ``np.newaxis``. These
110+
# are the same objects:
111+
112+
# >>> np.newaxis is None
113+
# True
114+
115+
"""
116+
if not isinstance(axis, tuple):
117+
axis = (axis,)
118+
for i in axis:
119+
a = xp.expand_dims(a, axis=i, xp=xp)
120+
return a
121+
122+
123+
def kron(a: Array, b: Array, *, xp: ModuleType):
124+
"""
125+
Kronecker product of two arrays.
126+
127+
Computes the Kronecker product, a composite array made of blocks of the
128+
second array scaled by the first.
129+
130+
Equivalent to ``numpy.kron`` for NumPy arrays.
131+
132+
Parameters
133+
----------
134+
a, b : array
135+
xp : array_namespace
136+
The standard-compatible namespace for `a` and `b`.
137+
138+
Returns
139+
-------
140+
res : array
141+
142+
Notes
143+
-----
144+
The function assumes that the number of dimensions of `a` and `b`
145+
are the same, if necessary prepending the smallest with ones.
146+
If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
147+
the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
148+
The elements are products of elements from `a` and `b`, organized
149+
explicitly by::
150+
151+
kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]
152+
153+
where::
154+
155+
kt = it * st + jt, t = 0,...,N
156+
157+
In the common 2-D case (N=1), the block structure can be visualized::
158+
159+
[[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
160+
[ ... ... ],
161+
[ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]
162+
163+
164+
Examples
165+
--------
166+
# >>> import numpy as np
167+
# >>> np.kron([1,10,100], [5,6,7])
168+
# array([ 5, 6, 7, ..., 500, 600, 700])
169+
# >>> np.kron([5,6,7], [1,10,100])
170+
# array([ 5, 50, 500, ..., 7, 70, 700])
171+
172+
# >>> np.kron(np.eye(2), np.ones((2,2)))
173+
# array([[1., 1., 0., 0.],
174+
# [1., 1., 0., 0.],
175+
# [0., 0., 1., 1.],
176+
# [0., 0., 1., 1.]])
177+
178+
# >>> a = np.arange(100).reshape((2,5,2,5))
179+
# >>> b = np.arange(24).reshape((2,3,4))
180+
# >>> c = np.kron(a,b)
181+
# >>> c.shape
182+
# (2, 10, 6, 20)
183+
# >>> I = (1,3,0,2)
184+
# >>> J = (0,2,1)
185+
# >>> J1 = (0,) + J # extend to ndim=4
186+
# >>> S1 = (1,) + b.shape
187+
# >>> K = tuple(np.array(I) * np.array(S1) + np.array(J1))
188+
# >>> c[K] == a[I]*b[J]
189+
# True
190+
191+
"""
192+
193+
b = xp.asarray(b)
194+
singletons = (1,) * (b.ndim - a.ndim)
195+
a = xp.broadcast_to(xp.asarray(a), singletons + a.shape)
196+
197+
nd_b, nd_a = b.ndim, a.ndim
198+
nd_max = max(nd_b, nd_a)
199+
if nd_a == 0 or nd_b == 0:
200+
return xp.multiply(a, b)
201+
202+
a_shape = a.shape
203+
b_shape = b.shape
204+
205+
# Equalise the shapes by prepending smaller one with 1s
206+
a_shape = (1,) * max(0, nd_b - nd_a) + a_shape
207+
b_shape = (1,) * max(0, nd_a - nd_b) + b_shape
208+
209+
# Insert empty dimensions
210+
a_arr = expand_dims(a, axis=tuple(range(nd_b - nd_a)), xp=xp)
211+
b_arr = expand_dims(b, axis=tuple(range(nd_a - nd_b)), xp=xp)
212+
213+
# Compute the product
214+
a_arr = expand_dims(a_arr, axis=tuple(range(1, nd_max * 2, 2)), xp=xp)
215+
b_arr = expand_dims(b_arr, axis=tuple(range(0, nd_max * 2, 2)), xp=xp)
216+
result = xp.multiply(a_arr, b_arr)
217+
218+
# Reshape back and return
219+
a_shape = xp.asarray(a_shape)
220+
b_shape = xp.asarray(b_shape)
221+
return xp.reshape(result, tuple(xp.multiply(a_shape, b_shape)))

0 commit comments

Comments
 (0)