Skip to content

Commit e0f3e37

Browse files
authored
Merge pull request #427 from ev-br/eig_range
BUG: limit the range of elements in test_{eig,eigvals}
2 parents 41379d1 + 9c44b34 commit e0f3e37

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -331,11 +331,14 @@ def matrix_shapes(draw, stack_shapes=shapes()):
331331
square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2])
332332

333333
@composite
334-
def finite_matrices(draw, shape=matrix_shapes()):
335-
return draw(arrays(dtype=floating_dtypes,
336-
shape=shape,
337-
elements=dict(allow_nan=False,
338-
allow_infinity=False)))
334+
def finite_matrices(draw, shape=matrix_shapes(), dtype=floating_dtypes, bound=None):
335+
# for now, only generate elements from (1, bound); cf symmetric_matrices
336+
elements = dict(allow_nan=False, allow_infinity=False)
337+
if bound is not None:
338+
elements.update(**dict(min_value=1, max_value=bound))
339+
340+
return draw(arrays(dtype=dtype, shape=shape, elements=elements))
341+
339342

340343
rtol_shared_matrix_shapes = shared(matrix_shapes())
341344
# Should we set a max_value here?

array_api_tests/test_linalg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ def test_eigvalsh(x):
335335
@pytest.mark.unvectorized
336336
@pytest.mark.xp_extension('linalg')
337337
@pytest.mark.min_version("2025.12")
338-
@given(x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes))
338+
@given(x=finite_matrices(dtype=all_floating_dtypes(), shape=square_matrix_shapes, bound=10))
339339
def test_eig(x):
340340
res = linalg.eig(x)
341341

@@ -370,7 +370,7 @@ def test_eig(x):
370370
@pytest.mark.unvectorized
371371
@pytest.mark.xp_extension('linalg')
372372
@pytest.mark.min_version("2025.12")
373-
@given(x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes))
373+
@given(x=finite_matrices(dtype=all_floating_dtypes(), shape=square_matrix_shapes, bound=10))
374374
def test_eigvals(x):
375375
res = linalg.eigvals(x)
376376
expected_dtype = dh.complex_dtype_for(x.dtype)

0 commit comments

Comments
 (0)