@@ -46,3 +46,176 @@ def atleast_nd(x: Array, *, ndim: int, xp: ModuleType) -> Array:
46
46
x = xp .expand_dims (x , axis = 0 )
47
47
x = atleast_nd (x , ndim = ndim , xp = xp )
48
48
return x
49
+
50
+
51
+ def expand_dims (a : Array , * , axis : tuple [int ] = (0 ,), xp : ModuleType ):
52
+ """
53
+ Expand the shape of an array.
54
+
55
+ Insert a new axis that will appear at the `axis` position in the expanded
56
+ array shape.
57
+
58
+ This is ``xp.expand_dims`` for ``axis`` an int *or a tuple of ints*.
59
+ Equivalent to ``numpy.expand_dims`` for NumPy arrays.
60
+
61
+ Parameters
62
+ ----------
63
+ a : array
64
+ axis : int or tuple of ints
65
+ Position(s) in the expanded axes where the new axis (or axes) is/are placed.
66
+ xp : array_namespace
67
+ The standard-compatible namespace for `a`.
68
+
69
+ Returns
70
+ -------
71
+ res : array
72
+ `a` with an expanded shape.
73
+
74
+ Examples
75
+ --------
76
+ # >>> import numpy as np
77
+ # >>> x = np.array([1, 2])
78
+ # >>> x.shape
79
+ # (2,)
80
+
81
+ # The following is equivalent to ``x[np.newaxis, :]`` or ``x[np.newaxis]``:
82
+
83
+ # >>> y = np.expand_dims(x, axis=0)
84
+ # >>> y
85
+ # array([[1, 2]])
86
+ # >>> y.shape
87
+ # (1, 2)
88
+
89
+ # The following is equivalent to ``x[:, np.newaxis]``:
90
+
91
+ # >>> y = np.expand_dims(x, axis=1)
92
+ # >>> y
93
+ # array([[1],
94
+ # [2]])
95
+ # >>> y.shape
96
+ # (2, 1)
97
+
98
+ # ``axis`` may also be a tuple:
99
+
100
+ # >>> y = np.expand_dims(x, axis=(0, 1))
101
+ # >>> y
102
+ # array([[[1, 2]]])
103
+
104
+ # >>> y = np.expand_dims(x, axis=(2, 0))
105
+ # >>> y
106
+ # array([[[1],
107
+ # [2]]])
108
+
109
+ # Note that some examples may use ``None`` instead of ``np.newaxis``. These
110
+ # are the same objects:
111
+
112
+ # >>> np.newaxis is None
113
+ # True
114
+
115
+ """
116
+ if not isinstance (axis , tuple ):
117
+ axis = (axis ,)
118
+ for i in axis :
119
+ a = xp .expand_dims (a , axis = i , xp = xp )
120
+ return a
121
+
122
+
123
+ def kron (a : Array , b : Array , * , xp : ModuleType ):
124
+ """
125
+ Kronecker product of two arrays.
126
+
127
+ Computes the Kronecker product, a composite array made of blocks of the
128
+ second array scaled by the first.
129
+
130
+ Equivalent to ``numpy.kron`` for NumPy arrays.
131
+
132
+ Parameters
133
+ ----------
134
+ a, b : array
135
+ xp : array_namespace
136
+ The standard-compatible namespace for `a` and `b`.
137
+
138
+ Returns
139
+ -------
140
+ res : array
141
+
142
+ Notes
143
+ -----
144
+ The function assumes that the number of dimensions of `a` and `b`
145
+ are the same, if necessary prepending the smallest with ones.
146
+ If ``a.shape = (r0,r1,..,rN)`` and ``b.shape = (s0,s1,...,sN)``,
147
+ the Kronecker product has shape ``(r0*s0, r1*s1, ..., rN*SN)``.
148
+ The elements are products of elements from `a` and `b`, organized
149
+ explicitly by::
150
+
151
+ kron(a,b)[k0,k1,...,kN] = a[i0,i1,...,iN] * b[j0,j1,...,jN]
152
+
153
+ where::
154
+
155
+ kt = it * st + jt, t = 0,...,N
156
+
157
+ In the common 2-D case (N=1), the block structure can be visualized::
158
+
159
+ [[ a[0,0]*b, a[0,1]*b, ... , a[0,-1]*b ],
160
+ [ ... ... ],
161
+ [ a[-1,0]*b, a[-1,1]*b, ... , a[-1,-1]*b ]]
162
+
163
+
164
+ Examples
165
+ --------
166
+ # >>> import numpy as np
167
+ # >>> np.kron([1,10,100], [5,6,7])
168
+ # array([ 5, 6, 7, ..., 500, 600, 700])
169
+ # >>> np.kron([5,6,7], [1,10,100])
170
+ # array([ 5, 50, 500, ..., 7, 70, 700])
171
+
172
+ # >>> np.kron(np.eye(2), np.ones((2,2)))
173
+ # array([[1., 1., 0., 0.],
174
+ # [1., 1., 0., 0.],
175
+ # [0., 0., 1., 1.],
176
+ # [0., 0., 1., 1.]])
177
+
178
+ # >>> a = np.arange(100).reshape((2,5,2,5))
179
+ # >>> b = np.arange(24).reshape((2,3,4))
180
+ # >>> c = np.kron(a,b)
181
+ # >>> c.shape
182
+ # (2, 10, 6, 20)
183
+ # >>> I = (1,3,0,2)
184
+ # >>> J = (0,2,1)
185
+ # >>> J1 = (0,) + J # extend to ndim=4
186
+ # >>> S1 = (1,) + b.shape
187
+ # >>> K = tuple(np.array(I) * np.array(S1) + np.array(J1))
188
+ # >>> c[K] == a[I]*b[J]
189
+ # True
190
+
191
+ """
192
+
193
+ b = xp .asarray (b )
194
+ singletons = (1 ,) * (b .ndim - a .ndim )
195
+ a = xp .broadcast_to (xp .asarray (a ), singletons + a .shape )
196
+
197
+ nd_b , nd_a = b .ndim , a .ndim
198
+ nd_max = max (nd_b , nd_a )
199
+ if nd_a == 0 or nd_b == 0 :
200
+ return xp .multiply (a , b )
201
+
202
+ a_shape = a .shape
203
+ b_shape = b .shape
204
+
205
+ # Equalise the shapes by prepending smaller one with 1s
206
+ a_shape = (1 ,) * max (0 , nd_b - nd_a ) + a_shape
207
+ b_shape = (1 ,) * max (0 , nd_a - nd_b ) + b_shape
208
+
209
+ # Insert empty dimensions
210
+ a_arr = expand_dims (a , axis = tuple (range (nd_b - nd_a )), xp = xp )
211
+ b_arr = expand_dims (b , axis = tuple (range (nd_a - nd_b )), xp = xp )
212
+
213
+ # Compute the product
214
+ a_arr = expand_dims (a_arr , axis = tuple (range (1 , nd_max * 2 , 2 )), xp = xp )
215
+ b_arr = expand_dims (b_arr , axis = tuple (range (0 , nd_max * 2 , 2 )), xp = xp )
216
+ result = xp .multiply (a_arr , b_arr )
217
+
218
+ # Reshape back and return
219
+ a_shape = xp .asarray (a_shape )
220
+ b_shape = xp .asarray (b_shape )
221
+ return xp .reshape (result , tuple (xp .multiply (a_shape , b_shape )))
0 commit comments