Skip to content

Commit 034df74

Browse files
authored
BUG: Fix delegation behaviour with atleast_3d (data-apis#514)
1 parent 20c314e commit 034df74

File tree

2 files changed

+109
-20
lines changed

2 files changed

+109
-20
lines changed

src/array_api_extra/_delegation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
6868
if xp is None:
6969
xp = array_namespace(x)
7070

71-
if 1 <= ndim <= 3 and (
71+
if 1 <= ndim <= 2 and (
7272
is_numpy_namespace(xp)
7373
or is_jax_namespace(xp)
7474
or is_dask_namespace(xp)

tests/test_funcs.py

Lines changed: 108 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
lazy_xp_function(setdiff1d, jax_jit=False)
5555
lazy_xp_function(sinc)
5656

57+
NestedFloatList = list[float] | list["NestedFloatList"]
58+
5759

5860
class TestApplyWhere:
5961
@staticmethod
@@ -291,7 +293,31 @@ def test_0D(self, xp: ModuleType):
291293
y = atleast_nd(x, ndim=5)
292294
xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1)))
293295

294-
def test_1D(self, xp: ModuleType):
296+
@pytest.mark.parametrize(
297+
("input_shape", "ndim", "expected_shape"),
298+
[
299+
((1,), 0, (1,)),
300+
((5,), 1, (5,)),
301+
((2,), 2, (1, 2)),
302+
((3,), 3, (1, 1, 3)),
303+
((2,), 5, (1, 1, 1, 1, 2)),
304+
],
305+
)
306+
def test_1D_shapes(
307+
self,
308+
input_shape: tuple[int],
309+
ndim: int,
310+
expected_shape: tuple[int],
311+
xp: ModuleType,
312+
):
313+
n = math.prod(input_shape)
314+
x = xp.asarray(np.arange(n).reshape(input_shape))
315+
y = atleast_nd(x, ndim=ndim)
316+
317+
assert y.shape == expected_shape
318+
assert xp.sum(y) == int(n * (n - 1) / 2)
319+
320+
def test_1D_values(self, xp: ModuleType):
295321
x = xp.asarray([0, 1])
296322

297323
y = atleast_nd(x, ndim=0)
@@ -306,8 +332,32 @@ def test_1D(self, xp: ModuleType):
306332
y = atleast_nd(x, ndim=5)
307333
xp_assert_equal(y, xp.asarray([[[[[0, 1]]]]]))
308334

309-
def test_2D(self, xp: ModuleType):
310-
x = xp.asarray([[3.0]])
335+
@pytest.mark.parametrize(
336+
("input_shape", "ndim", "expected_shape"),
337+
[
338+
((2, 1), 0, (2, 1)),
339+
((5, 2), 1, (5, 2)),
340+
((2, 1), 2, (2, 1)),
341+
((3, 1), 3, (1, 3, 1)),
342+
((2, 8), 5, (1, 1, 1, 2, 8)),
343+
],
344+
)
345+
def test_2D_shapes(
346+
self,
347+
input_shape: tuple[int],
348+
ndim: int,
349+
expected_shape: tuple[int],
350+
xp: ModuleType,
351+
):
352+
n = math.prod(input_shape)
353+
x = xp.asarray(np.arange(n).reshape(input_shape))
354+
y = atleast_nd(x, ndim=ndim)
355+
356+
assert y.shape == expected_shape
357+
assert xp.sum(y) == int(n * (n - 1) / 2)
358+
359+
def test_2D_values(self, xp: ModuleType):
360+
x = xp.asarray([[3.0], [4.0]])
311361

312362
y = atleast_nd(x, ndim=0)
313363
xp_assert_equal(y, x)
@@ -316,12 +366,36 @@ def test_2D(self, xp: ModuleType):
316366
xp_assert_equal(y, x)
317367

318368
y = atleast_nd(x, ndim=3)
319-
xp_assert_equal(y, 3 * xp.ones((1, 1, 1)))
369+
xp_assert_equal(y, xp.asarray([[[3.0], [4.0]]]))
320370

321371
y = atleast_nd(x, ndim=5)
322-
xp_assert_equal(y, 3 * xp.ones((1, 1, 1, 1, 1)))
372+
xp_assert_equal(y, xp.asarray([[[[[3.0], [4.0]]]]]))
373+
374+
@pytest.mark.parametrize(
375+
("input_shape", "ndim", "expected_shape"),
376+
[
377+
((2, 1, 1), 0, (2, 1, 1)),
378+
((1, 5, 2), 1, (1, 5, 2)),
379+
((2, 1, 1), 2, (2, 1, 1)),
380+
((1, 3, 1), 3, (1, 3, 1)),
381+
((2, 8, 1), 5, (1, 1, 2, 8, 1)),
382+
],
383+
)
384+
def test_3D_shapes(
385+
self,
386+
input_shape: tuple[int],
387+
ndim: int,
388+
expected_shape: tuple[int],
389+
xp: ModuleType,
390+
):
391+
n = math.prod(input_shape)
392+
x = xp.asarray(np.arange(n).reshape(input_shape))
393+
y = atleast_nd(x, ndim=ndim)
394+
395+
assert y.shape == expected_shape
396+
assert xp.sum(y) == int(n * (n - 1) / 2)
323397

324-
def test_3D(self, xp: ModuleType):
398+
def test_3D_values(self, xp: ModuleType):
325399
x = xp.asarray([[[3.0], [2.0]]])
326400

327401
y = atleast_nd(x, ndim=0)
@@ -336,8 +410,32 @@ def test_3D(self, xp: ModuleType):
336410
y = atleast_nd(x, ndim=5)
337411
xp_assert_equal(y, xp.asarray([[[[[3.0], [2.0]]]]]))
338412

339-
def test_5D(self, xp: ModuleType):
340-
x = xp.ones((1, 1, 1, 1, 1))
413+
@pytest.mark.parametrize(
414+
("input_shape", "ndim", "expected_shape"),
415+
[
416+
((2, 1, 1, 2, 1), 0, (2, 1, 1, 2, 1)),
417+
((1, 5, 2, 3, 2), 2, (1, 5, 2, 3, 2)),
418+
((2, 1, 1, 5, 2), 5, (2, 1, 1, 5, 2)),
419+
((1, 3, 1, 2, 1), 6, (1, 1, 3, 1, 2, 1)),
420+
((2, 8, 1, 9, 8), 9, (1, 1, 1, 1, 2, 8, 1, 9, 8)),
421+
],
422+
)
423+
def test_5D_shapes(
424+
self,
425+
input_shape: tuple[int],
426+
ndim: int,
427+
expected_shape: tuple[int],
428+
xp: ModuleType,
429+
):
430+
n = math.prod(input_shape)
431+
x = xp.asarray(np.arange(n).reshape(input_shape))
432+
y = atleast_nd(x, ndim=ndim)
433+
434+
assert y.shape == expected_shape
435+
assert xp.sum(y) == int(n * (n - 1) / 2)
436+
437+
def test_5D_values(self, xp: ModuleType):
438+
x = xp.asarray([[[[[3.0]], [[2.0]]]]])
341439

342440
y = atleast_nd(x, ndim=0)
343441
xp_assert_equal(y, x)
@@ -349,19 +447,10 @@ def test_5D(self, xp: ModuleType):
349447
xp_assert_equal(y, x)
350448

351449
y = atleast_nd(x, ndim=6)
352-
xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1)))
450+
xp_assert_equal(y, xp.asarray([[[[[[3.0]], [[2.0]]]]]]))
353451

354452
y = atleast_nd(x, ndim=9)
355-
xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1, 1, 1, 1, 1)))
356-
357-
def test_device(self, xp: ModuleType, device: Device):
358-
x = xp.asarray([1, 2, 3], device=device)
359-
assert get_device(atleast_nd(x, ndim=2)) == device
360-
361-
def test_xp(self, xp: ModuleType):
362-
x = xp.asarray(1.0)
363-
y = atleast_nd(x, ndim=1, xp=xp)
364-
xp_assert_equal(y, xp.ones((1,)))
453+
xp_assert_equal(y, xp.asarray([[[[[[[[[3.0]], [[2.0]]]]]]]]]))
365454

366455

367456
class TestBroadcastShapes:

0 commit comments

Comments
 (0)