|
1 | | -import numpy |
| 1 | +import numpy as np |
2 | 2 | import pytest |
3 | | - |
4 | | -import dpnp as inp |
5 | | - |
6 | | - |
7 | | -@pytest.mark.parametrize("type", [numpy.int64], ids=["int64"]) |
8 | | -def test_flat(type): |
9 | | - a = numpy.array([1, 0, 2, -3, -1, 2, 21, -9]) |
10 | | - ia = inp.array(a) |
11 | | - |
12 | | - result = ia.flat[0] |
13 | | - expected = a.flat[0] |
14 | | - numpy.testing.assert_array_equal(expected, result) |
15 | | - |
16 | | - |
17 | | -@pytest.mark.parametrize("type", [numpy.int64], ids=["int64"]) |
18 | | -def test_flat2(type): |
19 | | - a = numpy.arange(1, 7).reshape(2, 3) |
20 | | - ia = inp.array(a) |
21 | | - |
22 | | - result = ia.flat[3] |
23 | | - expected = a.flat[3] |
24 | | - numpy.testing.assert_array_equal(expected, result) |
25 | | - |
26 | | - |
27 | | -@pytest.mark.parametrize("type", [numpy.int64], ids=["int64"]) |
28 | | -def test_flat3(type): |
29 | | - a = numpy.arange(1, 7).reshape(2, 3).T |
30 | | - ia = inp.array(a) |
31 | | - |
32 | | - result = ia.flat[3] |
33 | | - expected = a.flat[3] |
34 | | - numpy.testing.assert_array_equal(expected, result) |
| 3 | +from numpy.testing import assert_array_equal |
| 4 | + |
| 5 | +import dpnp |
| 6 | + |
| 7 | + |
| 8 | +class TestFlatiter: |
| 9 | + @pytest.mark.parametrize( |
| 10 | + "a, index", |
| 11 | + [ |
| 12 | + (np.array([1, 0, 2, -3, -1, 2, 21, -9]), 0), |
| 13 | + (np.arange(1, 7).reshape(2, 3), 3), |
| 14 | + (np.arange(1, 7).reshape(2, 3).T, 3), |
| 15 | + ], |
| 16 | + ids=["1D array", "2D array", "2D.T array"], |
| 17 | + ) |
| 18 | + def test_flat_getitem(self, a, index): |
| 19 | + a_dp = dpnp.array(a) |
| 20 | + result = a_dp.flat[index] |
| 21 | + expected = a.flat[index] |
| 22 | + assert_array_equal(expected, result) |
| 23 | + |
| 24 | + def test_flat_iteration(self): |
| 25 | + a = np.array([[1, 2], [3, 4]]) |
| 26 | + a_dp = dpnp.array(a) |
| 27 | + for dp_val, np_val in zip(a_dp.flat, a.flat): |
| 28 | + assert dp_val == np_val |
| 29 | + |
| 30 | + def test_init_error(self): |
| 31 | + with pytest.raises(TypeError): |
| 32 | + _ = dpnp.flatiter([1, 2, 3]) |
| 33 | + |
| 34 | + def test_flat_key_error(self): |
| 35 | + a_dp = dpnp.array(42) |
| 36 | + with pytest.raises(KeyError): |
| 37 | + _ = a_dp.flat[1] |
| 38 | + |
| 39 | + def test_flat_invalid_key(self): |
| 40 | + a_dp = dpnp.array([1, 2, 3]) |
| 41 | + flat = dpnp.flatiter(a_dp) |
| 42 | + # check __getitem__ |
| 43 | + with pytest.raises(TypeError): |
| 44 | + _ = flat["invalid"] |
| 45 | + # check __setitem__ |
| 46 | + with pytest.raises(TypeError): |
| 47 | + flat["invalid"] = 42 |
| 48 | + |
| 49 | + def test_flat_out_of_bounds(self): |
| 50 | + a_dp = dpnp.array([1, 2, 3]) |
| 51 | + flat = dpnp.flatiter(a_dp) |
| 52 | + with pytest.raises(IndexError): |
| 53 | + _ = flat[10] |
0 commit comments