Description
I'm having trouble passing a Distribution object to a jitted function when that argument is marked for donation with donate_argnums:
import jax
import jax.numpy as jnp
from distrax import Categorical

x = Categorical(logits=jnp.zeros(4))

def foo(x):
    return x.logits

f = jax.jit(foo, donate_argnums=0).trace(x).lower().compile()
f(x)

Here's the traceback:
Traceback (most recent call last):
  File "foo.py", line 13, in <module>
    f(x)
  File ".venv/lib64/python3.11/site-packages/jax/_src/stages.py", line 849, in __call__
    return self._call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib64/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 3277, in aot_cache_miss
    outs, out_flat, args_flat = stages.Compiled.call(params, *args, **kwargs)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib64/python3.11/site-packages/jax/_src/stages.py", line 813, in call
    raise TypeError('\n'.join(msg))
TypeError: Function compiled with input pytree does not match the input pytree it was called with. There are 1 mismatches, including:
  * at args[0], seen <class 'distrax._src.distributions.categorical.Categorical'> with pytree metadata ([<class 'int'>, ArgInfo(_aval=ShapedArray(float32[4]), donated=True)], [False, False], PyTreeDef({'_dtype': *, '_logits': *, '_probs': None})) but now given <class 'distrax._src.distributions.categorical.Categorical'> with pytree metadata ([<class 'int'>, <jax._src.tree_util.Leaf object at 0x14e3749615d0>], [False, False], PyTreeDef({'_dtype': *, '_logits': *, '_probs': None})), so the pytree node metadata does not match
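The metadata in the error looks like a `(metadata, switch, treedef)` triple, which suggests the distribution's flatten routes each leaf either to the children or to static metadata depending on the leaf's *value*. Below is a toy reproduction of that mechanism, under the assumption that this is how distrax's `Jittable` works. `ToyJittable`, `is_array_like`, and `StandIn` are hypothetical names, and `StandIn` plays the role that `ArgInfo` presumably plays when JAX swaps leaves for per-argument info during lowering:

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def is_array_like(x) -> bool:
    """Hypothetical leaf classifier standing in for distrax's `_is_jax_data`."""
    return isinstance(x, jax.Array)


@tree_util.register_pytree_node_class
class ToyJittable:
    """Toy pytree that routes each attribute to children or static metadata."""

    def __init__(self, logits, dtype=int):
        self.logits = logits
        self.dtype = dtype

    def tree_flatten(self):
        leaves, treedef = tree_util.tree_flatten(self.__dict__)
        # The split depends on the *values* of the leaves, not only the structure.
        switch = tuple(is_array_like(leaf) for leaf in leaves)
        children = [leaf if s else None for leaf, s in zip(leaves, switch)]
        metadata = tuple(None if s else leaf for leaf, s in zip(leaves, switch))
        return children, (metadata, switch, treedef)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        metadata, switch, treedef = aux_data
        leaves = [c if s else m for c, m, s in zip(children, metadata, switch)]
        obj = object.__new__(cls)
        obj.__dict__.update(tree_util.tree_unflatten(treedef, leaves))
        return obj


class StandIn:
    """Hypothetical placeholder, playing the role of jax.stages.ArgInfo."""


x = ToyJittable(jnp.zeros(4))
_, treedef_array = tree_util.tree_flatten(x)

# Swap every array leaf for a non-array placeholder, then flatten again.
placeholder_tree = tree_util.tree_map(lambda leaf: StandIn(), x)
_, treedef_standin = tree_util.tree_flatten(placeholder_tree)

print(treedef_array.num_leaves, treedef_standin.num_leaves)  # 1 0
print(treedef_array == treedef_standin)  # False
```

Because the placeholder is not recognized as data, it is moved into the static metadata: the re-flattened treedef has no leaves and different aux data, the same kind of node-metadata mismatch the traceback reports.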
The code works fine if I explicitly flatten the Distribution and unflatten inside jit:
def foo(leaves, treedef):
    x = jax.tree.unflatten(treedef, leaves)
    return x.logits

leaves, treedef = jax.tree.flatten(x)
f = (
    jax.jit(foo, donate_argnums=0, static_argnums=1)
    .trace(leaves, treedef)
    .lower()
    .compile()
)
f(leaves)

But this is quite cumbersome. I think the correct approach is to also mark jax.stages.ArgInfo as a form of "Jax data" inside distrax._src.utils.jittable._is_jax_data.
Indeed, after making this change, the original code works as expected.
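A minimal sketch of the proposed change, assuming a leaf classifier shaped roughly like distrax's `_is_jax_data`. `is_jax_data_patched` is a hypothetical name, and its rules for `None`, Python scalars, and classes are assumptions rather than distrax's actual implementation; the only change the issue proposes is the `jax.stages.ArgInfo` check. The example obtains genuine `ArgInfo` objects from `Lowered.args_info` to confirm the predicate accepts them:

```python
import jax
import jax.numpy as jnp
from jax import tree_util


def is_jax_data_patched(x) -> bool:
    """Hypothetical leaf classifier: True if `x` should stay a pytree child.

    Sketches the change proposed in this issue; it is not distrax's code.
    """
    # Proposed addition: during lowering, JAX presumably swaps array leaves
    # for ArgInfo objects, which must keep being treated as data.
    if isinstance(x, jax.stages.ArgInfo):
        return True
    # Assumed rule: None, Python scalars, and classes (e.g. `int`) stay static.
    if x is None or isinstance(x, (bool, int, float, type)):
        return False
    # Anything else counts as data if JAX can convert it to an array.
    try:
        jnp.asarray(x)
        return True
    except TypeError:
        return False


# Obtain real ArgInfo objects through the public ahead-of-time API.
lowered = jax.jit(lambda a: a * 2, donate_argnums=0).lower(jnp.zeros(4))
arg_infos = tree_util.tree_leaves(
    lowered.args_info, is_leaf=lambda a: isinstance(a, jax.stages.ArgInfo)
)
print([is_jax_data_patched(a) for a in arg_infos])
```

Placing the `ArgInfo` check first keeps it from falling through to the array-conversion fallback, which would otherwise reject it and push the placeholder into the static metadata.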