diff --git a/doc/conf.py b/doc/conf.py index e10dcffb90..48d81730ba 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -38,6 +38,7 @@ "jax": ("https://jax.readthedocs.io/en/latest", None), "numpy": ("https://numpy.org/doc/stable", None), "torch": ("https://pytorch.org/docs/stable", None), + "equinox": ("https://docs.kidger.site/equinox/", None), } needs_sphinx = "3" diff --git a/doc/environment.yml b/doc/environment.yml index 7b564e8fb0..5b1f8790dc 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -25,4 +25,4 @@ dependencies: - ablog - pip - pip: - - -e .. + - -e ..[jax] diff --git a/doc/library/index.rst b/doc/library/index.rst index e9b362f8db..70506f6120 100644 --- a/doc/library/index.rst +++ b/doc/library/index.rst @@ -61,6 +61,13 @@ Convert to Variable .. autofunction:: pytensor.as_symbolic(...) +Wrap JAX functions +================== + +.. autofunction:: as_jax_op(...) + + Alias for :func:`pytensor.link.jax.ops.as_jax_op` + Debug ===== diff --git a/pyproject.toml b/pyproject.toml index 3eeb7a7c2b..82e13cb648 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [build-system] requires = [ - "setuptools>=59.0.0", - "cython", - "numpy>=1.17.0", - "versioneer[toml]==0.29", + "setuptools>=59.0.0", + "cython", + "numpy>=1.17.0", + "versioneer[toml]==0.29", ] build-backend = "setuptools.build_meta" @@ -17,44 +17,44 @@ readme = "README.rst" license = "BSD-3-Clause" license-files = ["LICENSE.txt"] classifiers = [ - "Development Status :: 6 - Mature", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "Intended Audience :: Developers", - "Programming Language :: Python", - "Topic :: Software Development :: Code Generators", - "Topic :: Software Development :: Compilers", - "Topic :: Scientific/Engineering :: Mathematics", - "Operating System :: Microsoft :: Windows", - "Operating System :: POSIX", - "Operating System :: Unix", - "Operating System :: MacOS", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", + "Development Status :: 6 - Mature", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "Programming Language :: Python", + "Topic :: Software Development :: Code Generators", + "Topic :: Software Development :: Compilers", + "Topic :: Scientific/Engineering :: Mathematics", + "Operating System :: Microsoft :: Windows", + "Operating System :: POSIX", + "Operating System :: Unix", + "Operating System :: MacOS", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] keywords = [ - "pytensor", - "math", - "numerical", - "symbolic", - "blas", - "numpy", - "autodiff", - "differentiation", + "pytensor", + "math", + "numerical", + "symbolic", + "blas", + "numpy", + "autodiff", + "differentiation", ] dependencies = [ - "setuptools>=59.0.0", - "scipy>=1,<2", - "numpy>=1.17.0", - "filelock>=3.15", - "etuples", - "logical-unification", - "miniKanren", - "cons", + "setuptools>=59.0.0", + "scipy>=1,<2", + "numpy>=1.17.0", + "filelock>=3.15", + "etuples", + "logical-unification", + "miniKanren", + "cons", ] [project.urls] @@ -70,16 +70,16 @@ pytensor-cache = "pytensor.bin.pytensor_cache:main" complete = ["pytensor[jax]", "pytensor[numba]"] development = 
["pytensor[complete]", "pytensor[tests]", "pytensor[rtd]"] tests = [ - "pytest", - "pre-commit", - "pytest-cov>=2.6.1", - "coverage>=5.1", - "pytest-benchmark", - "pytest-mock", - "pytest-sphinx", + "pytest", + "pre-commit", + "pytest-cov>=2.6.1", + "coverage>=5.1", + "pytest-benchmark", + "pytest-mock", + "pytest-sphinx", ] rtd = ["sphinx>=5.1.0,<6", "pygments", "pydot"] -jax = ["jax", "jaxlib"] +jax = ["jax", "jaxlib", "equinox"] numba = ["numba>=0.57", "llvmlite"] [tool.setuptools.packages.find] @@ -91,16 +91,16 @@ pytensor = ["py.typed"] [tool.coverage.run] omit = [ - "pytensor/_version.py", - "tests/*", - "pytensor/assert_op.py", - "pytensor/graph/opt.py", - "pytensor/graph/opt_utils.py", - "pytensor/graph/optdb.py", - "pytensor/graph/kanren.py", - "pytensor/graph/unify.py", - "pytensor/link/jax/jax_linker.py", - "pytensor/link/jax/jax_dispatch.py", + "pytensor/_version.py", + "tests/*", + "pytensor/assert_op.py", + "pytensor/graph/opt.py", + "pytensor/graph/opt_utils.py", + "pytensor/graph/optdb.py", + "pytensor/graph/kanren.py", + "pytensor/graph/unify.py", + "pytensor/link/jax/jax_linker.py", + "pytensor/link/jax/jax_dispatch.py", ] branch = true relative_files = true @@ -130,11 +130,24 @@ exclude = ["doc/", "pytensor/_version.py"] docstring-code-format = true [tool.ruff.lint] -select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"] +select = [ + "C", + "E", + "F", + "I", + "UP", + "W", + "RUF", + "PERF", + "PTH", + "ISC", + "T20", + "NPY201", +] ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"] unfixable = [ - # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead - "B905", + # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead + "B905", ] diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 3c925ac2f2..a7f9aa8058 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -167,6 +167,18 @@ def get_underlying_scalar_constant(v): from pytensor.scan.views import foldl, foldr, map, reduce from pytensor.compile.builders import OpFromGraph +try: + import pytensor.link.jax.ops + from pytensor.link.jax.ops import as_jax_op +except ImportError as e: + import_error_as_jax_op = e + + def as_jax_op(*args, **kwargs): + raise ImportError( + "JAX and/or equinox are not installed. Install them" + " to use this function: pip install pytensor[jax]" + ) from import_error_as_jax_op + # isort: on diff --git a/pytensor/link/jax/ops.py b/pytensor/link/jax/ops.py new file mode 100644 index 0000000000..da492fc71b --- /dev/null +++ b/pytensor/link/jax/ops.py @@ -0,0 +1,295 @@ +"""Convert a jax function to a pytensor compatible function.""" + +import logging +from collections.abc import Sequence +from functools import wraps + +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np + +import pytensor.tensor as pt +from pytensor.gradient import DisconnectedType +from pytensor.graph import Apply, Op, Variable +from pytensor.link.jax.dispatch import jax_funcify + + +log = logging.getLogger(__name__) + + +class JAXOp(Op): + """ + JAXOp is a PyTensor Op that wraps a JAX function, providing both forward computation and reverse-mode differentiation (via the VJPJAXOp class). + + Parameters + ---------- + input_types : list + A list of PyTensor types for each input variable. + output_types : list + A list of PyTensor types for each output variable. + flat_func : callable + The JAX function that computes outputs from inputs. 
+ name : str, optional + A custom name for the Op instance. If provided, the class name will be + updated accordingly. + + Example + ------- + This example defines a simple function that sums the input array with a dynamic shape. + + >>> import numpy as np + >>> import jax + >>> import jax.numpy as jnp + >>> from pytensor.tensor import TensorType + >>> + >>> # Create the jax function that sums the input array. + >>> def sum_function(x, y): + ... return jnp.sum(x + y) + >>> + >>> # Create the input and output types, input has a dynamic shape. + >>> input_type = TensorType("float32", shape=(None,)) + >>> output_type = TensorType("float32", shape=(1,)) + >>> + >>> # Instantiate a JAXOp + >>> op = JAXOp( + ... [input_type, input_type], [output_type], sum_function, name="DummyJAXOp" + ... ) + >>> # Define symbolic input variables. + >>> x = pt.tensor("x", dtype="float32", shape=(2,)) + >>> y = pt.tensor("y", dtype="float32", shape=(2,)) + >>> # Compile a PyTensor function. + >>> result = op(x, y) + >>> f = pytensor.function([x, y], [result]) + >>> print( + ... f( + ... np.array([2.0, 3.0], dtype=np.float32), + ... np.array([4.0, 5.0], dtype=np.float32), + ... ) + ... ) + [array(14., dtype=float32)] + >>> + >>> # Compute the gradient of op(x, y) with respect to x. + >>> g = pt.grad(result[0], x) + >>> grad_f = pytensor.function([x, y], [g]) + >>> print( + ... grad_f( + ... np.array([2.0, 3.0], dtype=np.float32), + ... np.array([4.0, 5.0], dtype=np.float32), + ... ) + ... ) + [array([1., 1.], dtype=float32)] + """ + + __props__ = ("input_types", "output_types", "jax_func", "name") + + def __init__(self, input_types, output_types, jax_func, name=None): + self.input_types = tuple(input_types) + self.output_types = tuple(output_types) + self.jax_func = jax_func + self.jitted_func = jax.jit(jax_func) + self.name = name + super().__init__() + + def __repr__(self): + base = self.__class__.__name__ + if self.name is not None: + base = f"{base}{self.name}" + props = list(self.__props__) + props.remove("name") + props = ",".join(f"{prop}={getattr(self, prop, '?')}" for prop in props) + return f"{base}({props})" + + def make_node(self, *inputs: Variable) -> Apply: + for input, expected in zip(inputs, self.input_types, strict=True): + if input.type != expected: + raise TypeError( + f"Expected input of type {expected}, got {input.type} instead." + ) + outputs = [typ() for typ in self.output_types] + return Apply(self, inputs, outputs) + + def perform(self, node, inputs, outputs): + results = self.jitted_func(*inputs) + if len(results) != len(outputs): + raise ValueError( + f"Expected {len(outputs)} outputs from jax function, got {len(results)}." 
+            )
+        for i, result in enumerate(results):
+            outputs[i][0] = np.array(result, self.output_types[i].dtype)
+
+    def perform_jax(self, *inputs):
+        return self.jitted_func(*inputs)
+
+    def grad(self, inputs, output_gradients):
+        # Only pull back cotangents for outputs that are connected to the cost;
+        # disconnected outputs show up here as DisconnectedType variables.
+        wrt_index = []
+        for i, out_grad in enumerate(output_gradients):
+            if not isinstance(out_grad.type, DisconnectedType):
+                wrt_index.append(i)
+
+        num_inputs = len(inputs)
+
+        def vjp_jax_op(*args):
+            inputs = args[:num_inputs]
+            covectors = args[num_inputs:]
+            assert len(covectors) == len(wrt_index)
+
+            def func_restricted(*inputs):
+                out = self.jax_func(*inputs)
+                # Return a tuple so the cotangent structure passed to vjp_fn
+                # (a tuple of covectors) matches the primal output structure.
+                return tuple(out[i] for i in wrt_index)
+
+            _primals, vjp_fn = jax.vjp(func_restricted, *inputs)
+            return vjp_fn(covectors)
+
+        op = JAXOp(
+            [*self.input_types, *(self.output_types[i] for i in wrt_index)],
+            [self.input_types[i] for i in range(num_inputs)],
+            vjp_jax_op,
+            name="VJP" + (self.name if self.name is not None else ""),
+        )
+
+        return op(
+            *inputs,
+            *(output_gradients[i] for i in wrt_index),
+            return_list=True,
+        )
+
+
+def as_jax_op(jaxfunc):
+    """Return a PyTensor-compatible function from a JAX-jittable function.
+
+    This decorator wraps a JAX function so that it accepts and returns
+    `pytensor.Variable` objects. The JAX-jittable function can accept any nested
+    Python structure (a `Pytree
+    <https://jax.readthedocs.io/en/latest/pytrees.html>`_) as input, and may
+    return any nested Python structure.
+
+    Parameters
+    ----------
+    jaxfunc : Callable
+        A JAX function to be wrapped.
+
+    Returns
+    -------
+    Callable
+        A function that wraps the given JAX function so that it can be called with
+        pytensor.Variable inputs and returns pytensor.Variable outputs.
+
+    Examples
+    --------
+
+    >>> import pytensor
+    >>> import jax.numpy as jnp
+    >>> import pytensor.tensor as pt
+    >>> @as_jax_op
+    ... def add(x, y):
+    ...     return jnp.add(x, y)
+    >>> x = pt.scalar("x")
+    >>> y = pt.scalar("y")
+    >>> result = add(x, y)
+    >>> f = pytensor.function([x, y], [result])
+    >>> print(f(1, 2))
+    [array(3.)]
+
+    Notes
+    -----
+    The function is based on a blog post by Ricardo Vieira and Adrian Seyboldt,
+    available at `pymc-labs.io <https://www.pymc-labs.io>`__.
+    To accept functions and non-PyTensor variables as input, the wrapper makes use
+    of :func:`equinox.partition` and :func:`equinox.combine` to split and combine
+    the variables. Static input shapes are resolved with
+    :func:`pytensor.tensor.basic.infer_static_shape` and output types are inferred
+    with :func:`jax.eval_shape`.
+
+    """
+    name = jaxfunc.__name__
+
+    @wraps(jaxfunc)
+    def func(*args, **kwargs):
+        # Partition the inputs into dynamic PyTensor variables and static
+        # variables (anything else, e.g. Python scalars, strings or functions).
+        # Static variables do not become part of the graph.
+
+        pt_vars, static_vars = eqx.partition(
+            (args, kwargs), lambda x: isinstance(x, pt.Variable)
+        )
+
+        # Flatten the pytree of PyTensor variables.
+        pt_vars_flat, pt_vars_treedef = jax.tree.flatten(
+            pt_vars,
+        )
+        pt_types = [var.type for var in pt_vars_flat]
+
+        # We need fully-known static shapes so that the output types can be
+        # determined with jax.eval_shape.
+        input_shapes = [var.type.shape for var in pt_vars_flat]
+        resolved_input_shapes = []
+        for var, shape in zip(pt_vars_flat, input_shapes, strict=True):
+            if any(s is None for s in shape):
+                _, shape = pt.basic.infer_static_shape(var.shape)
+            if any(s is None for s in shape):
+                raise ValueError(
+                    f"Input variable {var} has an undetermined static shape. "
+                    "Please provide inputs with fully determined static shapes, "
+                    "e.g. by using pt.specify_shape."
+                )
+            resolved_input_shapes.append(shape)
+
+        # Figure out the output types using jax.eval_shape.
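+        # jax.eval_shape traces the wrapped function abstractly (no values are
+        # computed) and returns ShapeDtypeStruct objects describing each output.
+        # As a side effect of that trace, the output pytree structure and any
+        # static (non-array) outputs are recorded in extra_output_storage.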
+ extra_output_storage = {} + + def wrap_jaxfunc(args): + vars = jax.tree.unflatten(pt_vars_treedef, args) + args, kwargs = eqx.combine( + vars, + static_vars, + ) + outputs = jaxfunc(*args, **kwargs) + output_vals, output_static = eqx.partition(outputs, eqx.is_array) + extra_output_storage["output_static"] = output_static + outputs_flat, output_treedef = jax.tree.flatten(output_vals) + extra_output_storage["output_treedef"] = output_treedef + return outputs_flat + + dummy_inputs = [ + jnp.ones(shape, dtype=var.type.dtype) + for var, shape in zip(pt_vars_flat, resolved_input_shapes, strict=True) + ] + + output_shapes_flat = jax.eval_shape(wrap_jaxfunc, dummy_inputs) + output_treedef = extra_output_storage["output_treedef"] + output_static = extra_output_storage["output_static"] + pt_output_types = [ + pt.TensorType(dtype=var.dtype, shape=var.shape) + for var in output_shapes_flat + ] + + def flat_func(*flat_vars): + vars = jax.tree.unflatten(pt_vars_treedef, flat_vars) + args, kwargs = eqx.combine( + vars, + static_vars, + ) + outputs = jaxfunc(*args, **kwargs) + output_vals, _ = eqx.partition(outputs, eqx.is_array) + outputs_flat, _ = jax.tree.flatten(output_vals) + return outputs_flat + + op_instance = JAXOp( + pt_types, + pt_output_types, + flat_func, + name=name, + ) + + # 8. Execute the op and unflatten the outputs. + output_flat = op_instance(*pt_vars_flat) + if not isinstance(output_flat, Sequence): + output_flat = [output_flat] + outvars = jax.tree.unflatten(output_treedef, output_flat) + outvars = eqx.combine(outvars, output_static) + + return outvars + + return func + + +@jax_funcify.register(JAXOp) +def jax_op_funcify(op, **kwargs): + return op.perform_jax diff --git a/tests/link/jax/test_as_jax_op.py b/tests/link/jax/test_as_jax_op.py new file mode 100644 index 0000000000..62fd270032 --- /dev/null +++ b/tests/link/jax/test_as_jax_op.py @@ -0,0 +1,474 @@ +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +import pytensor.tensor as pt +from pytensor import as_jax_op, config, grad +from pytensor.graph.fg import FunctionGraph +from pytensor.link.jax.ops import JAXOp +from pytensor.scalar import all_types +from pytensor.tensor import TensorType, tensor +from tests.link.jax.test_basic import compare_jax_and_py + + +def test_two_inputs_single_output(): + rng = np.random.default_rng(1) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return jax.nn.sigmoid(x + y) + + # Test with as_jax_op decorator + out = as_jax_op(f)(x, y) + grad_out = grad(pt.sum(out), [x, y]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,))], + f, + ) + out = jax_op(x, y) + grad_out = grad(pt.sum(out), [x, y]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_two_inputs_tuple_output(): + rng = np.random.default_rng(2) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return jax.nn.sigmoid(x + y), y * 2 + + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x, y) + grad_out = grad(pt.sum(out1 + out2), [x, y]) + + fg = FunctionGraph([x, y], [out1, out2, 
*grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + # must_be_device_array is False, because the with disabled jit compilation, + # inputs are not automatically transformed to jax.Array anymore + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out1, out2 = jax_op(x, y) + grad_out = grad(pt.sum(out1 + out2), [x, y]) + fg = FunctionGraph([x, y], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_two_inputs_list_output_one_unused_output(): + # One output is unused, to test whether the wrapper can handle DisconnectedType + rng = np.random.default_rng(3) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return [jax.nn.sigmoid(x + y), y * 2] + + # Test with as_jax_op decorator + out, _ = as_jax_op(f)(x, y) + grad_out = grad(pt.sum(out), [x, y]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out, _ = jax_op(x, y) + grad_out = grad(pt.sum(out), [x, y]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_single_input_tuple_output(): + rng = np.random.default_rng(4) + x = tensor("x", shape=(2,)) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + def f(x): + return jax.nn.sigmoid(x), x * 2 + + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) + grad_out = grad(pt.sum(out1), [x]) + + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type], + [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_scalar_input_tuple_output(): + rng = np.random.default_rng(5) + x = tensor("x", shape=()) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + def f(x): + return jax.nn.sigmoid(x), x + + # Test with as_jax_op decorator + out1, out2 = as_jax_op(f)(x) + grad_out = grad(pt.sum(out1), [x]) + + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type], + [TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_single_input_list_output(): + rng = np.random.default_rng(6) + x = tensor("x", shape=(2,)) + test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)] + + def f(x): + return [jax.nn.sigmoid(x), 2 * x] + + # Test with as_jax_op decorator + out1, out2 
= as_jax_op(f)(x) + grad_out = grad(pt.sum(out1), [x]) + + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + # Test direct JAXOp usage, with unspecified output shapes + jax_op = JAXOp( + [x.type], + [ + TensorType(config.floatX, shape=(None,)), + TensorType(config.floatX, shape=(None,)), + ], + f, + ) + out1, out2 = jax_op(x) + grad_out = grad(pt.sum(out1), [x]) + fg = FunctionGraph([x], [out1, out2, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_pytree_input_tuple_output(): + rng = np.random.default_rng(7) + x = tensor("x", shape=(2,)) + y = tensor("y", shape=(2,)) + y_tmp = {"y": y, "y2": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0] + + # Test with as_jax_op decorator + out = f(x, y_tmp) + grad_out = grad(pt.sum(out[1]), [x, y]) + + fg = FunctionGraph([x, y], [out[0], out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + +def test_pytree_input_pytree_output(): + rng = np.random.default_rng(8) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(1,)) + y_tmp = {"a": y, "b": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y) + + # Test with as_jax_op decorator + out = f(x, y_tmp) + grad_out = grad(pt.sum(out[1]["b"][0]), [x, y]) + + fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False) + + +def test_pytree_input_with_non_graph_args(): + rng = np.random.default_rng(9) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(1,)) + y_tmp = {"a": y, "b": [y**2]} + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f(x, y, depth, which_variable): + if which_variable == "x": + var = x + elif which_variable == "y": + var = y["a"] + y["b"][0] + else: + return "Unsupported argument" + for _ in range(depth): + var = jax.nn.sigmoid(var) + return var + + # Test with as_jax_op decorator + # arguments depth and which_variable are not part of the graph + out = f(x, y_tmp, depth=3, which_variable="x") + grad_out = grad(pt.sum(out), [x]) + fg = FunctionGraph([x, y], [out[0], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + out = f(x, y_tmp, depth=7, which_variable="y") + grad_out = grad(pt.sum(out), [x]) + fg = FunctionGraph([x, y], [out[0], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + out = f(x, y_tmp, depth=10, which_variable="z") + assert out == "Unsupported argument" + + +def test_unused_matrix_product(): + # A matrix output is unused, to test whether the wrapper can handle a + # DisconnectedType with a larger dimension. 
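+    # (JAXOp.grad only requests cotangents for outputs connected to the cost,
+    # so no VJP input is created for the unused (3, 3) matrix output.)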
+ + rng = np.random.default_rng(10) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + def f(x, y): + return x[:, None] @ y[None], jnp.exp(x) + + # Test with as_jax_op decorator + out = as_jax_op(f)(x, y) + grad_out = grad(pt.sum(out[1]), [x]) + + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [ + TensorType(config.floatX, shape=(3, 3)), + TensorType(config.floatX, shape=(3,)), + ], + f, + ) + out = jax_op(x, y) + grad_out = grad(pt.sum(out[1]), [x]) + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_unknown_static_shape(): + rng = np.random.default_rng(11) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + x_cumsum = pt.cumsum(x) # Now x_cumsum has an unknown shape + + def f(x, y): + return x * jnp.ones(3) + + out = as_jax_op(f)(x_cumsum, y) + grad_out = grad(pt.sum(out), [x]) + + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + # Test direct JAXOp usage + jax_op = JAXOp( + [x.type, y.type], + [TensorType(config.floatX, shape=(None,))], + f, + ) + out = jax_op(x_cumsum, y) + grad_out = grad(pt.sum(out), [x]) + fg = FunctionGraph([x, y], [out, *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + +def test_nested_functions(): + rng = np.random.default_rng(13) + x = tensor("x", shape=(3,)) + y = tensor("y", shape=(3,)) + test_values = [ + rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y) + ] + + @as_jax_op + def f_internal(y): + def f_ret(t): + return y + t + + def f_ret2(t): + return f_ret(t) + t**2 + + return f_ret, y**2 * jnp.ones(1), f_ret2 + + f, y_pow, f2 = f_internal(y) + + @as_jax_op + def f_outer(x, dict_other): + f, y_pow = dict_other["func"], dict_other["y"] + return x * jnp.ones(3), f(x) * y_pow + + out = f_outer(x, {"func": f, "y": y_pow}) + grad_out = grad(pt.sum(out[1]), [x]) + + fg = FunctionGraph([x, y], [out[1], *grad_out]) + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, test_values) + + +class TestDtypes: + @pytest.mark.parametrize("in_dtype", list(map(str, all_types))) + @pytest.mark.parametrize("out_dtype", list(map(str, all_types))) + def test_different_in_output(self, in_dtype, out_dtype): + x = tensor("x", shape=(3,), dtype=in_dtype) + y = tensor("y", shape=(3,), dtype=in_dtype) + + if "int" in in_dtype: + test_values = [ + np.random.randint(0, 10, size=(inp.type.shape)).astype(inp.type.dtype) + for inp in (x, y) + ] + else: + test_values = [ + np.random.normal(size=(inp.type.shape)).astype(inp.type.dtype) + for inp in (x, y) + ] + + @as_jax_op + def f(x, y): + out = jnp.add(x, y) + return jnp.real(out).astype(out_dtype) + + out = f(x, y) + assert out.dtype == out_dtype + + if "float" in in_dtype and "float" in out_dtype: + grad_out = grad(out[0], [x, y]) + assert grad_out[0].dtype == in_dtype + fg = FunctionGraph([x, y], [out, *grad_out]) + else: + fg = FunctionGraph([x, y], [out]) + + fn, _ = compare_jax_and_py(fg, test_values) + + with jax.disable_jit(): + fn, _ = compare_jax_and_py(fg, 
test_values)
+
+    @pytest.mark.parametrize("in1_dtype", list(map(str, all_types)))
+    @pytest.mark.parametrize("in2_dtype", list(map(str, all_types)))
+    def test_different_inputs(self, in1_dtype, in2_dtype):
+        x = tensor("x", shape=(3,), dtype=in1_dtype)
+        y = tensor("y", shape=(3,), dtype=in2_dtype)
+
+        if "int" in in1_dtype:
+            test_values = [np.random.randint(0, 10, size=(3,)).astype(x.type.dtype)]
+        else:
+            test_values = [np.random.normal(size=(3,)).astype(x.type.dtype)]
+        if "int" in in2_dtype:
+            test_values.append(np.random.randint(0, 10, size=(3,)).astype(y.type.dtype))
+        else:
+            test_values.append(np.random.normal(size=(3,)).astype(y.type.dtype))
+
+        @as_jax_op
+        def f(x, y):
+            out = jnp.add(x, y)
+            return jnp.real(out).astype(in1_dtype)
+
+        out = f(x, y)
+        assert out.dtype == in1_dtype
+
+        if "float" in in1_dtype and "float" in in2_dtype:
+            # In principle the gradient should also be defined when only the
+            # second input is an integer, but that case currently fails, so the
+            # gradient is only checked for float/float inputs.
+            grad_out = grad(out[0], [x])
+            assert grad_out[0].dtype == in1_dtype
+            fg = FunctionGraph([x, y], [out, *grad_out])
+        else:
+            fg = FunctionGraph([x, y], [out])
+
+        fn, _ = compare_jax_and_py(fg, test_values)
+
+        with jax.disable_jit():
+            fn, _ = compare_jax_and_py(fg, test_values)
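
For reference, a minimal usage sketch of the decorator added in this diff. The function and variable names (`weighted_norm`, `w`, `b`, `eps`) are illustrative only and not part of the change; it assumes the `pytensor[jax]` extra is installed.

import jax.numpy as jnp
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor import as_jax_op


@as_jax_op
def weighted_norm(params, eps):
    # `params` is a pytree of arrays; `eps` is a plain Python float, so equinox
    # partitions it out as a static value that never enters the PyTensor graph.
    return jnp.sqrt(jnp.sum(params["w"] ** 2) + jnp.sum(params["b"] ** 2) + eps)


w = pt.tensor("w", shape=(3,))
b = pt.tensor("b", shape=(2,))

# Inputs need fully determined static shapes (or pt.specify_shape) so that the
# output types can be inferred via jax.eval_shape.
norm = weighted_norm({"w": w, "b": b}, eps=1e-6)
dnorm_dw = pytensor.grad(norm, w)  # reverse mode goes through JAXOp.grad / jax.vjp

# Compiles with the default backend (JAXOp.perform calls the jitted function);
# mode="JAX" would instead dispatch the whole graph through jax_funcify.
f = pytensor.function([w, b], [norm, dnorm_dw])
print(f(np.ones(3, dtype=w.dtype), np.zeros(2, dtype=b.dtype)))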