@@ -197,6 +197,55 @@ def expand_dims(
197
197
return _funcs .expand_dims (a , axis = axis , xp = xp )
198
198
199
199
200
+ def atleast_nd (x : Array , / , * , ndim : int , xp : ModuleType | None = None ) -> Array :
201
+ """
202
+ Recursively expand the dimension of an array to at least `ndim`.
203
+
204
+ Parameters
205
+ ----------
206
+ x : array
207
+ Input array.
208
+ ndim : int
209
+ The minimum number of dimensions for the result.
210
+ xp : array_namespace, optional
211
+ The standard-compatible namespace for `x`. Default: infer.
212
+
213
+ Returns
214
+ -------
215
+ array
216
+ An array with ``res.ndim`` >= `ndim`.
217
+ If ``x.ndim`` >= `ndim`, `x` is returned.
218
+ If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
219
+ until ``res.ndim`` equals `ndim`.
220
+
221
+ Examples
222
+ --------
223
+ >>> import array_api_strict as xp
224
+ >>> import array_api_extra as xpx
225
+ >>> x = xp.asarray([1])
226
+ >>> xpx.atleast_nd(x, ndim=3, xp=xp)
227
+ Array([[[1]]], dtype=array_api_strict.int64)
228
+
229
+ >>> x = xp.asarray([[[1, 2],
230
+ ... [3, 4]]])
231
+ >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
232
+ True
233
+ """
234
+ if xp is None :
235
+ xp = array_namespace (x )
236
+
237
+ if 1 <= ndim <= 3 and (
238
+ is_numpy_namespace (xp )
239
+ or is_jax_namespace (xp )
240
+ or is_dask_namespace (xp )
241
+ or is_cupy_namespace (xp )
242
+ or is_torch_namespace (xp )
243
+ ):
244
+ return getattr (xp , f"atleast_{ ndim } d" )(x )
245
+
246
+ return _funcs .atleast_nd (x , ndim = ndim , xp = xp )
247
+
248
+
200
249
def isclose (
201
250
a : Array | complex ,
202
251
b : Array | complex ,
0 commit comments