Skip to content

Commit 6ba1c51

Browse files
committed
reduce atleast_test verbosity.
1 parent 664f853 commit 6ba1c51

File tree

1 file changed

+116
-100
lines changed

1 file changed

+116
-100
lines changed

tests/test_funcs.py

Lines changed: 116 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -294,147 +294,163 @@ def test_0D(self, xp: ModuleType):
294294
xp_assert_equal(y, xp.ones((1, 1, 1, 1, 1)))
295295

296296
@pytest.mark.parametrize(
297-
("x_data", "ndim", "expected_data"),
297+
("input_shape", "ndim", "expected_shape"),
298298
[
299-
# --- size-1 vector ---
300-
([3.0], 0, [3.0]),
301-
([3.0], 1, [3.0]),
302-
([3.0], 2, [[3.0]]),
303-
([3.0], 3, [[[3.0]]]),
304-
([3.0], 5, [[[[[3.0]]]]]),
305-
# --- size-2 vector ---
306-
([0.0, 1.0], 0, [0.0, 1.0]),
307-
([0.0, 1.0], 1, [0.0, 1.0]),
308-
([0.0, 1.0], 2, [[0.0, 1.0]]),
309-
([0.0, 1.0], 5, [[[[[0.0, 1.0]]]]]),
299+
((1,), 0, (1,)),
300+
((5,), 1, (5,)),
301+
((2,), 2, (1, 2)),
302+
((3,), 3, (1, 1, 3)),
303+
((2,), 5, (1, 1, 1, 1, 2)),
310304
],
311305
)
312-
def test_1D(
306+
def test_1D_shapes(
313307
self,
314-
x_data: NestedFloatList,
308+
input_shape: tuple[int],
315309
ndim: int,
316-
expected_data: NestedFloatList,
310+
expected_shape: tuple[int],
317311
xp: ModuleType,
318312
):
319-
x = xp.asarray(x_data)
320-
expected = xp.asarray(expected_data)
313+
n = math.prod(input_shape)
314+
x = xp.reshape(xp.asarray(list(range(n))), input_shape)
321315
y = atleast_nd(x, ndim=ndim)
322-
xp_assert_equal(y, expected)
323316

324-
@pytest.mark.parametrize(
325-
("x_data", "ndim", "expected_data"),
326-
[
327-
# --- size-1 vector ---
328-
([[3.0]], 0, [[3.0]]),
329-
([[3.0]], 1, [[3.0]]),
330-
([[3.0]], 2, [[3.0]]),
331-
([[3.0]], 3, [[[3.0]]]),
332-
([[3.0]], 5, [[[[[3.0]]]]]),
333-
# --- size-2 vector ---
334-
([[0.0], [1.0]], 0, [[0.0], [1.0]]),
335-
([[0.0, 1.0]], 1, [[0.0, 1.0]]),
336-
([[0.0, 1.0]], 2, [[0.0, 1.0]]),
337-
([[0.0], [1.0]], 3, [[[0.0], [1.0]]]),
338-
([[0.0, 1.0]], 5, [[[[[0.0, 1.0]]]]]),
339-
],
340-
)
341-
def test_2D(
342-
self,
343-
x_data: NestedFloatList,
344-
ndim: int,
345-
expected_data: NestedFloatList,
346-
xp: ModuleType,
347-
):
348-
x = xp.asarray(x_data)
349-
expected = xp.asarray(expected_data)
350-
y = atleast_nd(x, ndim=ndim)
351-
xp_assert_equal(y, expected)
317+
assert y.shape == expected_shape
318+
assert xp.sum(y) == int(n * (n - 1) / 2)
319+
320+
def test_1D_values(self, xp: ModuleType):
321+
x = xp.asarray([0, 1])
322+
323+
y = atleast_nd(x, ndim=0)
324+
xp_assert_equal(y, x)
325+
326+
y = atleast_nd(x, ndim=1)
327+
xp_assert_equal(y, x)
328+
329+
y = atleast_nd(x, ndim=2)
330+
xp_assert_equal(y, xp.asarray([[0, 1]]))
331+
332+
y = atleast_nd(x, ndim=5)
333+
xp_assert_equal(y, xp.asarray([[[[[0, 1]]]]]))
352334

353335
@pytest.mark.parametrize(
354-
("x_data", "ndim", "expected_data"),
336+
("input_shape", "ndim", "expected_shape"),
355337
[
356-
([[[0.0]], [[1.0]]], 0, [[[0.0]], [[1.0]]]),
357-
([[[0.0], [1.0]]], 1, [[[0.0], [1.0]]]),
358-
([[[0.0, 1.0]]], 2, [[[0.0, 1.0]]]),
359-
([[[0.0]], [[1.0]]], 3, [[[0.0]], [[1.0]]]),
360-
([[[0.0], [1.0]]], 5, [[[[[0.0], [1.0]]]]]),
338+
((2, 1), 0, (2, 1)),
339+
((5, 2), 1, (5, 2)),
340+
((2, 1), 2, (2, 1)),
341+
((3, 1), 3, (1, 3, 1)),
342+
((2, 8), 5, (1, 1, 1, 2, 8)),
361343
],
362344
)
363-
def test_3D(
345+
def test_2D_shapes(
364346
self,
365-
x_data: NestedFloatList,
347+
input_shape: tuple[int],
366348
ndim: int,
367-
expected_data: NestedFloatList,
349+
expected_shape: tuple[int],
368350
xp: ModuleType,
369351
):
370-
x = xp.asarray(x_data)
371-
expected = xp.asarray(expected_data)
352+
n = math.prod(input_shape)
353+
x = xp.reshape(xp.asarray(list(range(n))), input_shape)
372354
y = atleast_nd(x, ndim=ndim)
373-
xp_assert_equal(y, expected)
355+
356+
assert y.shape == expected_shape
357+
assert xp.sum(y) == int(n * (n - 1) / 2)
358+
359+
def test_2D_values(self, xp: ModuleType):
360+
x = xp.asarray([[3.0], [4.0]])
361+
362+
y = atleast_nd(x, ndim=0)
363+
xp_assert_equal(y, x)
364+
365+
y = atleast_nd(x, ndim=2)
366+
xp_assert_equal(y, x)
367+
368+
y = atleast_nd(x, ndim=3)
369+
xp_assert_equal(y, xp.asarray([[[3.0], [4.0]]]))
370+
371+
y = atleast_nd(x, ndim=5)
372+
xp_assert_equal(y, xp.asarray([[[[[3.0], [4.0]]]]]))
374373

375374
@pytest.mark.parametrize(
376-
("x_data", "ndim", "expected_data"),
375+
("input_shape", "ndim", "expected_shape"),
377376
[
378-
([[[[3.0], [2.0]]]], 0, [[[[3.0], [2.0]]]]),
379-
([[[[3.0, 2.0]]]], 2, [[[[3.0, 2.0]]]]),
380-
([[[[3.0]], [[2.0]]]], 4, [[[[3.0]], [[2.0]]]]),
381-
([[[[3.0]]], [[[2.0]]]], 5, [[[[[3.0]]], [[[2.0]]]]]),
377+
((2, 1, 1), 0, (2, 1, 1)),
378+
((1, 5, 2), 1, (1, 5, 2)),
379+
((2, 1, 1), 2, (2, 1, 1)),
380+
((1, 3, 1), 3, (1, 3, 1)),
381+
((2, 8, 1), 5, (1, 1, 2, 8, 1)),
382382
],
383383
)
384-
def test_4D(
384+
def test_3D_shapes(
385385
self,
386-
x_data: NestedFloatList,
386+
input_shape: tuple[int],
387387
ndim: int,
388-
expected_data: NestedFloatList,
388+
expected_shape: tuple[int],
389389
xp: ModuleType,
390390
):
391-
x = xp.asarray(x_data)
392-
expected = xp.asarray(expected_data)
391+
n = math.prod(input_shape)
392+
x = xp.reshape(xp.asarray(list(range(n))), input_shape)
393393
y = atleast_nd(x, ndim=ndim)
394-
xp_assert_equal(y, expected)
394+
395+
assert y.shape == expected_shape
396+
assert xp.sum(y) == int(n * (n - 1) / 2)
397+
398+
def test_3D_values(self, xp: ModuleType):
399+
x = xp.asarray([[[3.0], [2.0]]])
400+
401+
y = atleast_nd(x, ndim=0)
402+
xp_assert_equal(y, x)
403+
404+
y = atleast_nd(x, ndim=2)
405+
xp_assert_equal(y, x)
406+
407+
y = atleast_nd(x, ndim=3)
408+
xp_assert_equal(y, x)
409+
410+
y = atleast_nd(x, ndim=5)
411+
xp_assert_equal(y, xp.asarray([[[[[3.0], [2.0]]]]]))
395412

396413
@pytest.mark.parametrize(
397-
("x_data", "ndim", "expected_data"),
414+
("input_shape", "ndim", "expected_shape"),
398415
[
399-
([[[[[3.0]], [[2.0]], [[1.0]]]]], 0, [[[[[3.0]], [[2.0]], [[1.0]]]]]),
400-
([[[[[3.0, 2.0, 6.0]]]]], 2, [[[[[3.0, 2.0, 6.0]]]]]),
401-
(
402-
[[[[[3.0]]], [[[2.0]]], [[[1.0]]]]],
403-
4,
404-
[[[[[3.0]]], [[[2.0]]], [[[1.0]]]]],
405-
),
406-
(
407-
[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]],
408-
6,
409-
[[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]]],
410-
),
411-
(
412-
[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]],
413-
9,
414-
[[[[[[[[[3.0]], [[1.0]]], [[[2.0]], [[1.0]]], [[[1.0]], [[1.0]]]]]]]]],
415-
),
416+
((2, 1, 1, 2, 1), 0, (2, 1, 1, 2, 1)),
417+
((1, 5, 2, 3, 2), 2, (1, 5, 2, 3, 2)),
418+
((2, 1, 1, 5, 2), 5, (2, 1, 1, 5, 2)),
419+
((1, 3, 1, 2, 1), 6, (1, 1, 3, 1, 2, 1)),
420+
((2, 8, 1, 9, 8), 9, (1, 1, 1, 1, 2, 8, 1, 9, 8)),
416421
],
417422
)
418-
def test_5D(
423+
def test_5D_shapes(
419424
self,
420-
x_data: NestedFloatList,
425+
input_shape: tuple[int],
421426
ndim: int,
422-
expected_data: NestedFloatList,
427+
expected_shape: tuple[int],
423428
xp: ModuleType,
424429
):
425-
x = xp.asarray(x_data)
426-
expected = xp.asarray(expected_data)
430+
n = math.prod(input_shape)
431+
x = xp.reshape(xp.asarray(list(range(n))), input_shape)
427432
y = atleast_nd(x, ndim=ndim)
428-
xp_assert_equal(y, expected)
429433

430-
def test_device(self, xp: ModuleType, device: Device):
431-
x = xp.asarray([1, 2, 3], device=device)
432-
assert get_device(atleast_nd(x, ndim=2)) == device
434+
assert y.shape == expected_shape
435+
assert xp.sum(y) == int(n * (n - 1) / 2)
433436

434-
def test_xp(self, xp: ModuleType):
435-
x = xp.asarray(1.0)
436-
y = atleast_nd(x, ndim=1, xp=xp)
437-
xp_assert_equal(y, xp.ones((1,)))
437+
def test_5D_values(self, xp: ModuleType):
438+
x = xp.asarray([[[[[3.0]], [[2.0]]]]])
439+
440+
y = atleast_nd(x, ndim=0)
441+
xp_assert_equal(y, x)
442+
443+
y = atleast_nd(x, ndim=4)
444+
xp_assert_equal(y, x)
445+
446+
y = atleast_nd(x, ndim=5)
447+
xp_assert_equal(y, x)
448+
449+
y = atleast_nd(x, ndim=6)
450+
xp_assert_equal(y, xp.asarray([[[[[[3.0]], [[2.0]]]]]]))
451+
452+
y = atleast_nd(x, ndim=9)
453+
xp_assert_equal(y, xp.asarray([[[[[[[[[3.0]], [[2.0]]]]]]]]]))
438454

439455

440456
class TestBroadcastShapes:

0 commit comments

Comments
 (0)