From 64c8c0f8fa9c249cab4facdcbaee535ec35d764e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 10 Jan 2026 17:59:30 +0100 Subject: [PATCH] ENH: add tests for `eig` and `eigvals` --- array_api_tests/dtype_helpers.py | 20 ++++++++++++ array_api_tests/test_linalg.py | 53 ++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/array_api_tests/dtype_helpers.py b/array_api_tests/dtype_helpers.py index f7fa306b..6555b89d 100644 --- a/array_api_tests/dtype_helpers.py +++ b/array_api_tests/dtype_helpers.py @@ -199,6 +199,26 @@ def is_scalar(x): return isinstance(x, (int, float, complex, bool)) +def complex_for_float(dtyp): + """For a real or complex dtype, return a matching complex dtype.""" + if api_version <= '2021.12': + raise TypeError("complex dtypes require api_version >= 2022.12.") + + if dtyp not in all_float_dtypes: + raise ValueError(f"expected a real dtype, got {dtyp}.") + + if dtyp == xp.float32: + return xp.complex64 + elif dtyp == xp.float64: + return xp.complex128 + elif dtyp == xp.complex64: + return xp.complex64 + elif dtype == xp.complex128: + return xp.complex128 + else: + raise ValueError(f"Unknown dtype {dtyp}.") + + def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping: dtype_value_pairs = [] for name, value in mapping.items(): diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 6f4608da..887a1681 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -331,6 +331,59 @@ def test_eigvalsh(x): # TODO: Test that res actually corresponds to the eigenvalues of x + +@pytest.mark.unvectorized +@pytest.mark.xp_extension('linalg') +@pytest.mark.min_version("2025.12") +@given(x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes)) +def test_eig(x): + res = linalg.eig(x) + + _test_namedtuple(res, ['eigenvalues', 'eigenvectors'], 'eig') + + eigenvalues = res.eigenvalues + eigenvectors = res.eigenvectors + expected_dtype = dh.complex_for_float(x.dtype) + + ph.assert_dtype("eig", in_dtype=x.dtype, out_dtype=eigenvalues.dtype, + expected=expected_dtype, repr_name="eigenvalues.dtype") + ph.assert_result_shape("eig", in_shapes=[x.shape], + out_shape=eigenvalues.shape, + expected=x.shape[:-1], + repr_name="eigenvalues.shape") + + ph.assert_dtype("eig", in_dtype=x.dtype, out_dtype=eigenvectors.dtype, + expected=expected_dtype, repr_name="eigenvectors.dtype") + ph.assert_result_shape("eig", in_shapes=[x.shape], + out_shape=eigenvectors.shape, expected=x.shape, + repr_name="eigenvectors.shape") + + # TODO: Test that eigenvectors are orthonormal. + + _test_stacks(lambda x: linalg.eig(x).eigenvectors, x, + res=eigenvectors, dims=2) + + # TODO: Test that res actually corresponds to the eigenvalues and + # eigenvectors of x + + +@pytest.mark.unvectorized +@pytest.mark.xp_extension('linalg') +@pytest.mark.min_version("2025.12") +@given(x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes)) +def test_eigvals(x): + res = linalg.eigvals(x) + expected_dtype = dh.complex_for_float(x.dtype) + + ph.assert_dtype("eigvals", in_dtype=x.dtype, out_dtype=res.dtype, + expected=expected_dtype, repr_name="eigvals") + ph.assert_result_shape("eigvals", in_shapes=[x.shape], + out_shape=res.shape, expected=x.shape[:-1]) + # TODO: Test that res actually corresponds to the eigenvalues of x + + _test_stacks(linalg.eigvals, x, res=res, dims=1) + + @pytest.mark.unvectorized @pytest.mark.xp_extension('linalg') @given(x=invertible_matrices())