Unexpected behaviour when passing a Distribution to a function with donate_argnums #308

@adzcai

I'm having trouble passing a Distribution object to a jitted function when that argument is marked for donation with donate_argnums:

import jax
import jax.numpy as jnp
from distrax import Categorical

x = Categorical(logits=jnp.zeros(4))

def foo(x):
    return x.logits

f = jax.jit(foo, donate_argnums=0).trace(x).lower().compile()
f(x)

Here's the traceback:

Traceback (most recent call last):
  File "foo.py", line 13, in <module>
    f(x)
  File ".venv/lib64/python3.11/site-packages/jax/_src/stages.py", line 849, in __call__
    return self._call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib64/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 3277, in aot_cache_miss
    outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib64/python3.11/site-packages/jax/_src/stages.py", line 813, in call
    raise TypeError('\n'.join(msg))
TypeError: Function compiled with input pytree does not match the input pytree it was called with. There are 1 mismatches, including:
    * at args[0], seen <class 'distrax._src.distributions.categorical.Categorical'> with pytree metadata ([<class 'int'>, ArgInfo(_aval=ShapedArray(float32[4]), donated=True)], [False, False], PyTreeDef({'_dtype': *, '_logits': *, '_probs': None})) but now given <class 'distrax._src.distributions.categorical.Categorical'> with pytree metadata ([<class 'int'>, <jax._src.tree_util.Leaf object at 0x14e3749615d0>], [False, False], PyTreeDef({'_dtype': *, '_logits': *, '_probs': None})), so the pytree node metadata does not match

The code works fine if I explicitly flatten the Distribution and unflatten inside jit:

def foo(leaves, treedef):
    x = jax.tree.unflatten(treedef, leaves)
    return x.logits


leaves, treedef = jax.tree.flatten(x)
f = (
    jax.jit(foo, donate_argnums=0, static_argnums=1)
    .trace(leaves, treedef)
    .lower()
    .compile()
)
f(leaves)

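But this is quite cumbersome. The compile-time metadata in the traceback points at the cause: the [False, False] flags suggest that the ArgInfo placeholder standing in for _logits was classified as static metadata instead of as a leaf, so the treedef no longer matches the one built from the real arguments at call time. I think the correct approach is to also treat jax.stages.ArgInfo as a form of "JAX data" inside distrax._src.utils.jittable._is_jax_data.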

Indeed, after making this change, the original code works as expected.
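To illustrate the idea without patching Distrax, here is a minimal sketch of a Distrax-style pytree whose leaf predicate treats jax.stages.ArgInfo as data. Both `ToyCategorical` and `is_jax_data` are hypothetical stand-ins, not Distrax code, and the predicate's rules (Python scalars, strings, types and None are static; everything else is a leaf) are an assumption rather than a copy of `_is_jax_data`. It assumes a JAX version that provides `jax.jit(...).trace` and `jax.stages.ArgInfo`, as in the report above.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


def is_jax_data(x):
    """Hypothetical leaf predicate, loosely modelled on distrax's _is_jax_data.

    Returns True if an attribute should become a pytree leaf, False if it
    should be stored as static metadata in the treedef.
    """
    # The proposed fix: with donate_argnums, JAX rebuilds the argument pytree
    # with ArgInfo placeholders; keep them as leaves. (Listed explicitly to
    # mirror the fix; the fallback below would also accept them.)
    if isinstance(x, jax.stages.ArgInfo):
        return True
    # Python scalars, strings, types and None are static metadata.
    if x is None or isinstance(x, (bool, int, float, complex, str, type)):
        return False
    # Anything else (arrays, tracers, other JAX placeholders) is a leaf.
    return True


@register_pytree_node_class
class ToyCategorical:
    """Hypothetical stand-in for a Distrax distribution (not distrax code)."""

    def __init__(self, logits, dtype=int):
        self._logits = logits
        self._dtype = dtype

    @property
    def logits(self):
        return self._logits

    def tree_flatten(self):
        attrs = vars(self)
        names = sorted(attrs)
        leaf_names = tuple(n for n in names if is_jax_data(attrs[n]))
        static = tuple((n, attrs[n]) for n in names if n not in leaf_names)
        children = tuple(attrs[n] for n in leaf_names)
        # Aux data must be hashable and compare equal between compile and call.
        return children, (leaf_names, static)

    @classmethod
    def tree_unflatten(cls, aux, children):
        leaf_names, static = aux
        obj = object.__new__(cls)
        for name, value in static:
            setattr(obj, name, value)
        for name, value in zip(leaf_names, children):
            setattr(obj, name, value)
        return obj


def foo(d):
    return d.logits


x = ToyCategorical(logits=jnp.zeros(4))
f = jax.jit(foo, donate_argnums=0).trace(x).lower().compile()
out = f(x)  # x's buffer may be donated; do not reuse x after this call
```

Because the placeholders land in the leaves rather than in the aux data, the structure JAX records at compile time should match the one it sees at call time. On backends without donation support JAX only emits a warning and copies instead.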
