diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index e3d2adf461..7298d5df61 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -78,7 +78,7 @@ jobs:
         install-jax: [0]
         install-torch: [0]
         part:
-          - "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link"
+          - "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
           - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
           - "tests/scan"
           - "tests/sparse"
@@ -97,9 +97,9 @@ jobs:
             part: "tests/tensor/test_math.py"
           - fast-compile: 1
             float32: 1
-          - part: "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link"
+          - part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
            float32: 1
-          - part: "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link"
+          - part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
            fast-compile: 1
         include:
           - install-numba: 1
diff --git a/pyproject.toml b/pyproject.toml
index 81fe82c79c..42c2289dde 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -116,9 +116,10 @@ versionfile_source = "pytensor/_version.py"
 versionfile_build = "pytensor/_version.py"
 tag_prefix = "rel-"
 
-[tool.pytest]
-addopts = "--durations=50 --doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
-testpaths = "tests/"
+[tool.pytest.ini_options]
+addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py"
+testpaths = ["pytensor/", "tests/"]
+xfail_strict = true
 
 [tool.ruff]
 line-length = 88
diff --git a/pytensor/link/pytorch/dispatch/shape.py b/pytensor/link/pytorch/dispatch/shape.py
index 7633e28e01..e249a81a70 100644
--- a/pytensor/link/pytorch/dispatch/shape.py
+++ b/pytensor/link/pytorch/dispatch/shape.py
@@ -15,7 +15,7 @@ def reshape(x, shape):
 
 @pytorch_funcify.register(Shape)
 def pytorch_funcify_Shape(op, **kwargs):
     def shape(x):
-        return x.shape
+        return torch.tensor(x.shape)
 
     return shape
diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py
index 4f53ec29f7..75e7ec0776 100644
--- a/pytensor/link/pytorch/dispatch/subtensor.py
+++ b/pytensor/link/pytorch/dispatch/subtensor.py
@@ -34,8 +34,13 @@ def subtensor(x, *flattened_indices):
 
 
 @pytorch_funcify.register(MakeSlice)
 def pytorch_funcify_makeslice(op, **kwargs):
-    def makeslice(*x):
-        return slice(x)
+    def makeslice(start, stop, step):
+        # Torch does not like numpy integers in indexing slices
+        return slice(
+            None if start is None else int(start),
+            None if stop is None else int(stop),
+            None if step is None else int(step),
+        )
 
     return makeslice
diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py
index 6db6ae2638..e7093a82bd 100644
--- a/pytensor/tensor/nlinalg.py
+++ b/pytensor/tensor/nlinalg.py
@@ -240,7 +240,7 @@ class SLogDet(Op):
     """
 
     __props__ = ()
-    gufunc_signature = "(m, m)->(),()"
+    gufunc_signature = "(m,m)->(),()"
     gufunc_spec = ("numpy.linalg.slogdet", 1, 2)
 
     def make_node(self, x):
diff --git a/tests/link/jax/test_basic.py b/tests/link/jax/test_basic.py
index 5cd2bd54c6..5e783984e0 100644
--- a/tests/link/jax/test_basic.py
+++ b/tests/link/jax/test_basic.py
@@ -6,13 +6,15 @@
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
-from pytensor.compile.mode import get_mode
+from pytensor.compile.mode import JAX, Mode
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.configdefaults import config
+from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op, get_test_value
 from pytensor.ifelse import ifelse
+from pytensor.link.jax import JAXLinker
 from pytensor.raise_op import assert_op
 from pytensor.tensor.type import dscalar, matrices, scalar, vector
 
 
@@ -26,9 +28,9 @@ def set_pytensor_flags():
 
 jax = pytest.importorskip("jax")
 
-# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
-jax_mode = get_mode("JAX")
-py_mode = get_mode("FAST_COMPILE")
+optimizer = RewriteDatabaseQuery(include=["jax"], exclude=JAX._optimizer.exclude)
+jax_mode = Mode(linker=JAXLinker(), optimizer=optimizer)
+py_mode = Mode(linker="py", optimizer=None)
 
 
 def compare_jax_and_py(
diff --git a/tests/link/jax/test_einsum.py b/tests/link/jax/test_einsum.py
index 9a55670c64..5761563066 100644
--- a/tests/link/jax/test_einsum.py
+++ b/tests/link/jax/test_einsum.py
@@ -1,8 +1,9 @@
 import numpy as np
 import pytest
 
-import pytensor
 import pytensor.tensor as pt
+from pytensor.graph import FunctionGraph
+from tests.link.jax.test_basic import compare_jax_and_py
 
 
 jax = pytest.importorskip("jax")
@@ -19,12 +20,10 @@ def test_jax_einsum():
         pt.tensor(name, shape=shape) for name, shape in zip("xyz", shapes)
     )
     out = pt.einsum(subscripts, x_pt, y_pt, z_pt)
-    f = pytensor.function([x_pt, y_pt, z_pt], out, mode="JAX")
+    fg = FunctionGraph([x_pt, y_pt, z_pt], [out])
+    compare_jax_and_py(fg, [x, y, z])
 
-    np.testing.assert_allclose(f(x, y, z), np.einsum(subscripts, x, y, z))
 
-
-@pytest.mark.xfail(raises=NotImplementedError)
 def test_ellipsis_einsum():
     subscripts = "...i,...i->..."
     x = np.random.rand(2, 5)
@@ -33,6 +32,5 @@ def test_ellipsis_einsum():
     x_pt = pt.tensor("x", shape=x.shape)
     y_pt = pt.tensor("y", shape=y.shape)
     out = pt.einsum(subscripts, x_pt, y_pt)
-    f = pytensor.function([x_pt, y_pt], out, mode="JAX")
-
-    np.testing.assert_allclose(f(x, y), np.einsum(subscripts, x, y))
+    fg = FunctionGraph([x_pt, y_pt], [out])
+    compare_jax_and_py(fg, [x, y])
diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py
index 94c442b165..1427413379 100644
--- a/tests/link/jax/test_extra_ops.py
+++ b/tests/link/jax/test_extra_ops.py
@@ -1,59 +1,52 @@
 import numpy as np
 import pytest
-from packaging.version import parse as version_parse
 
 import pytensor.tensor.basic as ptb
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
 from pytensor.tensor import extra_ops as pt_extra_ops
-from pytensor.tensor.type import matrix
+from pytensor.tensor.type import matrix, tensor
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
 jax = pytest.importorskip("jax")
 
 
-def set_test_value(x, v):
-    x.tag.test_value = v
-    return x
-
-
 def test_extra_ops():
     a = matrix("a")
-    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
+    a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
 
     out = pt_extra_ops.cumsum(a, axis=0)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
 
     out = pt_extra_ops.cumprod(a, axis=1)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
 
     out = pt_extra_ops.diff(a, n=2, axis=1)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
 
     out = pt_extra_ops.repeat(a, (3, 3), axis=1)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
 
     c = ptb.as_tensor(5)
-
     out = pt_extra_ops.fill_diagonal(a, c)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
 
     with pytest.raises(NotImplementedError):
         out = pt_extra_ops.fill_diagonal_offset(a, c, c)
         fgraph = FunctionGraph([a], [out])
-        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+        compare_jax_and_py(fgraph, [a_test])
 
     with pytest.raises(NotImplementedError):
         out = pt_extra_ops.Unique(axis=1)(a)
         fgraph = FunctionGraph([a], [out])
-        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+        compare_jax_and_py(fgraph, [a_test])
 
     indices = np.arange(np.prod((3, 4)))
     out = pt_extra_ops.unravel_index(indices, (3, 4), order="C")
@@ -63,40 +56,30 @@ def test_extra_ops():
     )
 
 
-@pytest.mark.xfail(
-    version_parse(jax.__version__) >= version_parse("0.2.12"),
-    reason="JAX Numpy API does not support dynamic shapes",
-)
-def test_extra_ops_dynamic_shapes():
-    a = matrix("a")
-    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
-
-    # This function also cannot take symbolic input.
-    c = ptb.as_tensor(5)
+@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
+def test_bartlett_dynamic_shape():
+    c = tensor(shape=(), dtype=int)
     out = pt_extra_ops.bartlett(c)
-    fgraph = FunctionGraph([], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    fgraph = FunctionGraph([c], [out])
+    compare_jax_and_py(fgraph, [np.array(5)])
 
-    multi_index = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
-    out = pt_extra_ops.ravel_multi_index(multi_index, (3, 4))
-    fgraph = FunctionGraph([], [out])
-    compare_jax_and_py(
-        fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
-    )
-
-    # The inputs are "concrete", yet it still has problems?
-    out = pt_extra_ops.Unique()(
-        ptb.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
-    )
-    fgraph = FunctionGraph([], [out])
-    compare_jax_and_py(fgraph, [])
+
+@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
+def test_ravel_multi_index_dynamic_shape():
+    x_test, y_test = np.unravel_index(np.arange(np.prod((3, 4))), (3, 4))
+
+    x = tensor(shape=(None,), dtype=int)
+    y = tensor(shape=(None,), dtype=int)
+    out = pt_extra_ops.ravel_multi_index((x, y), (3, 4))
+    # The graph must declare x and y as inputs to match the test values below
+    fgraph = FunctionGraph([x, y], [out])
+    compare_jax_and_py(fgraph, [x_test, y_test])
 
 
-@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
-def test_unique_nonconcrete():
+@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
+def test_unique_dynamic_shape():
     a = matrix("a")
-    a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
+    a_test = np.arange(6, dtype=config.floatX).reshape((3, 2))
     out = pt_extra_ops.Unique()(a)
     fgraph = FunctionGraph([a], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py(fgraph, [a_test])
diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py
index dfbc888e30..f9ae5d00c1 100644
--- a/tests/link/jax/test_random.py
+++ b/tests/link/jax/test_random.py
@@ -705,7 +705,7 @@ def test_multinomial():
     n = np.array([10, 40])
     p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
     g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
-    g_fn = compile_random_function([], g, mode=jax_mode)
+    g_fn = compile_random_function([], g, mode="JAX")
     samples = g_fn()
     np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
     np.testing.assert_allclose(
diff --git a/tests/link/jax/test_scan.py b/tests/link/jax/test_scan.py
index 61edacbc7b..ae64cad4c0 100644
--- a/tests/link/jax/test_scan.py
+++ b/tests/link/jax/test_scan.py
@@ -32,7 +32,7 @@ def test_scan_sit_sot(view):
         xs = xs[view]
     fg = FunctionGraph([x0], [xs])
     test_input_vals = [np.e]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 @pytest.mark.parametrize("view", [None, (-1,), slice(-4, -1, None)])
@@ -47,7 +47,7 @@ def test_scan_mit_sot(view):
         xs = xs[view]
     fg = FunctionGraph([x0], [xs])
     test_input_vals = [np.full((3,), np.e)]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 @pytest.mark.parametrize("view_x", [None, (-1,), slice(-4, -1, None)])
@@ -74,7 +74,7 @@ def step(xtm3, xtm1, ytm4, ytm2):
 
     fg = FunctionGraph([x0, y0], [xs, ys])
     test_input_vals = [np.full((3,), np.e), np.full((4,), np.pi)]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 @pytest.mark.parametrize("view", [None, (-2,), slice(None, None, 2)])
@@ -283,7 +283,7 @@ def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
         gamma_val,
         delta_val,
     ]
-    compare_jax_and_py(out_fg, test_input_vals)
+    compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")
 
 
 def test_scan_mitsot_with_nonseq():
@@ -316,7 +316,7 @@ def input_step_fn(y_tm1, y_tm3, a):
 
     out_fg = FunctionGraph([a_pt], [y_scan_pt])
     test_input_vals = [np.array(10.0).astype(config.floatX)]
-    compare_jax_and_py(out_fg, test_input_vals)
+    compare_jax_and_py(out_fg, test_input_vals, jax_mode="JAX")
 
 
 @pytest.mark.parametrize("x0_func", [dvector, dmatrix])
@@ -334,7 +334,6 @@ def test_nd_scan_sit_sot(x0_func, A_func):
         non_sequences=[A],
         outputs_info=[x0],
         n_steps=n_steps,
-        mode=get_mode("JAX"),
     )
 
     x0_val = (
@@ -346,7 +345,7 @@ def test_nd_scan_sit_sot(x0_func, A_func):
 
     fg = FunctionGraph([x0, A], [xs])
     test_input_vals = [x0_val, A_val]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 def test_nd_scan_sit_sot_with_seq():
@@ -362,7 +361,6 @@ def test_nd_scan_sit_sot_with_seq():
        non_sequences=[A],
        sequences=[x],
        n_steps=n_steps,
-       mode=get_mode("JAX"),
     )
 
     x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k)
@@ -370,7 +368,7 @@ def test_nd_scan_sit_sot_with_seq():
 
     fg = FunctionGraph([x, A], [xs])
     test_input_vals = [x_val, A_val]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 def test_nd_scan_mit_sot():
@@ -384,7 +382,6 @@ def test_nd_scan_mit_sot():
         outputs_info=[{"initial": x0, "taps": [-3, -1]}],
         non_sequences=[A, B],
         n_steps=10,
-        mode=get_mode("JAX"),
     )
 
     fg = FunctionGraph([x0, A, B], [xs])
@@ -393,7 +390,7 @@ def test_nd_scan_mit_sot():
     B_val = np.eye(3, dtype=config.floatX)
 
     test_input_vals = [x0_val, A_val, B_val]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 def test_nd_scan_sit_sot_with_carry():
@@ -417,7 +414,7 @@ def step(x, A):
     A_val = np.eye(3, dtype=config.floatX)
 
     test_input_vals = [x0_val, A_val]
-    compare_jax_and_py(fg, test_input_vals)
+    compare_jax_and_py(fg, test_input_vals, jax_mode="JAX")
 
 
 def test_default_mode_excludes_incompatible_rewrites():
@@ -426,7 +423,7 @@ def test_default_mode_excludes_incompatible_rewrites():
     B = matrix("B")
     out, _ = scan(lambda a, b: a @ b, outputs_info=[A], non_sequences=[B], n_steps=2)
     fg = FunctionGraph([A, B], [out])
-    compare_jax_and_py(fg, [np.eye(3), np.eye(3)])
+    compare_jax_and_py(fg, [np.eye(3), np.eye(3)], jax_mode="JAX")
 
 
 def test_dynamic_sequence_length():
diff --git a/tests/link/jax/test_sparse.py b/tests/link/jax/test_sparse.py
index 0c377bdcd8..c53aa301af 100644
--- a/tests/link/jax/test_sparse.py
+++ b/tests/link/jax/test_sparse.py
@@ -51,7 +51,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
 
     dot_pt = op(x_pt, y_pt)
     fgraph = FunctionGraph(inputs, [dot_pt])
-    compare_jax_and_py(fgraph, test_values)
+    compare_jax_and_py(fgraph, test_values, jax_mode="JAX")
 
 
 def test_sparse_dot_non_const_raises():
diff --git a/tests/link/jax/test_tensor_basic.py b/tests/link/jax/test_tensor_basic.py
index afa4191b9d..0ee4a236d9 100644
--- a/tests/link/jax/test_tensor_basic.py
+++ b/tests/link/jax/test_tensor_basic.py
@@ -74,7 +74,7 @@ def test_arange_of_shape():
     x = vector("x")
     out = ptb.arange(1, x.shape[-1], 2)
     fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [np.zeros((5,))])
+    compare_jax_and_py(fgraph, [np.zeros((5,))], jax_mode="JAX")
 
 
 def test_arange_nonconcrete():
diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py
index bb1958f43e..93035f52f4 100644
--- a/tests/link/pytorch/test_basic.py
+++ b/tests/link/pytorch/test_basic.py
@@ -7,13 +7,15 @@
 import pytensor.tensor.basic as ptb
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
-from pytensor.compile.mode import get_mode
+from pytensor.compile.mode import PYTORCH, Mode
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.configdefaults import config
+from pytensor.graph import RewriteDatabaseQuery
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.ifelse import ifelse
+from pytensor.link.pytorch.linker import PytorchLinker
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor import alloc, arange, as_tensor, empty, eye
 from pytensor.tensor.type import matrices, matrix, scalar, vector
@@ -22,8 +24,13 @@
 
 torch = pytest.importorskip("torch")
 
-pytorch_mode = get_mode("PYTORCH")
-py_mode = get_mode("FAST_COMPILE")
+optimizer = RewriteDatabaseQuery(
+    # While we don't have a PyTorch implementation of Blockwise
+    include=["local_useless_unbatched_blockwise"],
+    exclude=PYTORCH._optimizer.exclude,
+)
+pytorch_mode = Mode(linker=PytorchLinker(), optimizer=optimizer)
+py_mode = Mode(linker="py", optimizer=None)
 
 
 def compare_pytorch_and_py(
@@ -220,7 +227,7 @@ def test_alloc_and_empty():
     assert res.dtype == torch.float32
 
     v = vector("v", shape=(3,), dtype="float64")
-    out = alloc(v, (dim0, dim1, 3))
+    out = alloc(v, dim0, dim1, 3)
     compare_pytorch_and_py(
         FunctionGraph([v, dim1], [out]),
         [np.array([1, 2, 3]), np.array(7)],
diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py
index f651b14e0a..9488a9f688 100644
--- a/tests/tensor/rewriting/test_elemwise.py
+++ b/tests/tensor/rewriting/test_elemwise.py
@@ -1056,7 +1056,6 @@ def test_big_fusion(self):
             for node in dlogp.maker.fgraph.toposort()
         )
 
-    @pytest.mark.xfail(reason="Fails due to #1244")
     def test_add_mul_fusion_precedence(self):
         """Test that additions and multiplications are "fused together" before
         a `Composite` `Op` is introduced. This fusion is done by canonicalization
diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py
index a556e3a275..9134b29b65 100644
--- a/tests/unittest_tools.py
+++ b/tests/unittest_tools.py
@@ -27,8 +27,8 @@ def fetch_seed(pseed=None):
     None, which is equivalent to seeding with a random seed. Useful for
     seeding RandomState or Generator objects.
 
-    >>> rng = np.random.RandomState(unittest_tools.fetch_seed())
-    >>> rng = np.random.default_rng(unittest_tools.fetch_seed())
+    >>> rng = np.random.RandomState(fetch_seed())
+    >>> rng = np.random.default_rng(fetch_seed())
     """
 
     seed = pseed or config.unittests__rseed