Skip to content

Commit feb1ca8

Browse files
Update test_flat.py
1 parent 2027b27 commit feb1ca8

File tree

1 file changed

+52
-33
lines changed

1 file changed

+52
-33
lines changed

dpnp/tests/test_flat.py

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,53 @@
1-
import numpy
1+
import numpy as np
22
import pytest
3-
4-
import dpnp as inp
5-
6-
7-
@pytest.mark.parametrize("type", [numpy.int64], ids=["int64"])
8-
def test_flat(type):
9-
a = numpy.array([1, 0, 2, -3, -1, 2, 21, -9])
10-
ia = inp.array(a)
11-
12-
result = ia.flat[0]
13-
expected = a.flat[0]
14-
numpy.testing.assert_array_equal(expected, result)
15-
16-
17-
@pytest.mark.parametrize("type", [numpy.int64], ids=["int64"])
18-
def test_flat2(type):
19-
a = numpy.arange(1, 7).reshape(2, 3)
20-
ia = inp.array(a)
21-
22-
result = ia.flat[3]
23-
expected = a.flat[3]
24-
numpy.testing.assert_array_equal(expected, result)
25-
26-
27-
@pytest.mark.parametrize("type", [numpy.int64], ids=["int64"])
28-
def test_flat3(type):
29-
a = numpy.arange(1, 7).reshape(2, 3).T
30-
ia = inp.array(a)
31-
32-
result = ia.flat[3]
33-
expected = a.flat[3]
34-
numpy.testing.assert_array_equal(expected, result)
3+
from numpy.testing import assert_array_equal
4+
5+
import dpnp
6+
7+
8+
class TestFlatiter:
9+
@pytest.mark.parametrize(
10+
"a, index",
11+
[
12+
(np.array([1, 0, 2, -3, -1, 2, 21, -9]), 0),
13+
(np.arange(1, 7).reshape(2, 3), 3),
14+
(np.arange(1, 7).reshape(2, 3).T, 3),
15+
],
16+
ids=["1D array", "2D array", "2D.T array"],
17+
)
18+
def test_flat_getitem(self, a, index):
19+
a_dp = dpnp.array(a)
20+
result = a_dp.flat[index]
21+
expected = a.flat[index]
22+
assert_array_equal(expected, result)
23+
24+
def test_flat_iteration(self):
25+
a = np.array([[1, 2], [3, 4]])
26+
a_dp = dpnp.array(a)
27+
for dp_val, np_val in zip(a_dp.flat, a.flat):
28+
assert dp_val == np_val
29+
30+
def test_init_error(self):
31+
with pytest.raises(TypeError):
32+
_ = dpnp.flatiter([1, 2, 3])
33+
34+
def test_flat_key_error(self):
35+
a_dp = dpnp.array(42)
36+
with pytest.raises(KeyError):
37+
_ = a_dp.flat[1]
38+
39+
def test_flat_invalid_key(self):
40+
a_dp = dpnp.array([1, 2, 3])
41+
flat = dpnp.flatiter(a_dp)
42+
# check __getitem__
43+
with pytest.raises(TypeError):
44+
_ = flat["invalid"]
45+
# check __setitem__
46+
with pytest.raises(TypeError):
47+
flat["invalid"] = 42
48+
49+
def test_flat_out_of_bounds(self):
50+
a_dp = dpnp.array([1, 2, 3])
51+
flat = dpnp.flatiter(a_dp)
52+
with pytest.raises(IndexError):
53+
_ = flat[10]

0 commit comments

Comments
 (0)