Skip to content

Commit d4b7994

Browse files
committed
refactor to use pytest.mark.parameterize
1 parent e5732c7 commit d4b7994

File tree

1 file changed

+15
-22
lines changed

1 file changed

+15
-22
lines changed

numpy/_core/tests/test_numeric.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,29 +1710,22 @@ def test_sparse(self):
17101710
assert_equal(np.nonzero(c)[0],
17111711
np.concatenate((np.arange(10 + i, 20 + i), [20 + i*2])))
17121712

1713-
def test_nonzero_dtypes(self):
1713+
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
1714+
def test_nonzero_float_dtypes(self, dtype):
17141715
rng = np.random.default_rng(seed = 10)
1715-
zero_indices = np.arange(50)
1716-
1717-
# test for different dtypes
1718-
types = [bool, np.float32, np.float64]
1719-
sample = ((2**33)*rng.normal(size=100))
1720-
for dtype in types:
1721-
x = sample.astype(dtype)
1722-
rng.shuffle(zero_indices)
1723-
x[zero_indices] = 0
1724-
idxs = np.nonzero(x)[0]
1725-
assert_equal(np.array_equal(np.where(x != 0)[0], idxs), True)
1726-
1727-
integer_types = [np.int8, np.int16, np.int32, np.int64,
1728-
np.uint8, np.uint16, np.uint32, np.uint64]
1729-
sample = rng.integers(0, 255, size=100)
1730-
for dtype in integer_types:
1731-
x = sample.astype(dtype)
1732-
rng.shuffle(zero_indices)
1733-
x[zero_indices] = 0
1734-
idxs = np.nonzero(x)[0]
1735-
assert_equal(np.array_equal(np.where(x != 0)[0], idxs), True)
1716+
x = ((2**33)*rng.normal(size=100)).astype(dtype)
1717+
x[rng.choice(50, size=100)] = 0
1718+
idxs = np.nonzero(x)[0]
1719+
assert_equal(np.array_equal(np.where(x != 0)[0], idxs), True)
1720+
1721+
@pytest.mark.parametrize('dtype', [bool, np.int8, np.int16, np.int32, np.int64,
1722+
np.uint8, np.uint16, np.uint32, np.uint64])
1723+
def test_nonzero_integer_dtypes(self, dtype):
1724+
rng = np.random.default_rng(seed = 10)
1725+
x = rng.integers(0, 255, size=100).astype(dtype)
1726+
x[rng.choice(50, size=100)] = 0
1727+
idxs = np.nonzero(x)[0]
1728+
assert_equal(np.array_equal(np.where(x != 0)[0], idxs), True)
17361729

17371730

17381731
def test_return_type(self):

0 commit comments

Comments
 (0)