(Edit: self-contained MWE at the bottom)

I'm working on a custom language that includes a set of primitives, which I desugar to constructs that JAX understands. I started testing programs which accept

    def hidden_markov_model(key, T, config):
        z0 = initial_position(config)
        key, z = trace("z", kernel)(key, (T, z0, config))
        return key, z

Here,
When I attempt to execute this through some of my API functions, I get the following error:
E.g. in the

I handle the primitives later (using a special transformation that removes them from the code), so really they are just used as indicator sites for the later transformation.

More generally: if an object is a

Edit: here is a completely self-contained MWE demonstrating the thing I'm running into:

    import jax

    # A pytree node with no leaves: it flattens to no children and no aux data.
    @jax.tree_util.register_pytree_node_class
    class Foo:
        def tree_flatten(self):
            return (), ()

        @classmethod
        def tree_unflatten(cls, xs, data):
            return Foo()

    g_p = jax.core.Primitive("some_g")

    def g(*args):
        # The arguments are handed to bind as-is, without flattening.
        return g_p.bind(*args)

    def bar(foo):
        return g(foo)

    jaxpr = jax.make_jaxpr(bar)(Foo())
    print(jaxpr)

Should my "primitive wrapper" function
Replies: 1 comment
To support this, I did exactly what I indicated with my question: flatten/unflatten across `bind` boundaries.

A simple idiom: attach the `tree_form` from `jax.tree_util.tree_flatten` as metadata to the primitive (as a keyword argument to `bind`).
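
For anyone landing here later, a minimal sketch of that idiom might look like the following. Everything named here is illustrative rather than taken from my actual code: the `Config` pytree, the `trace_site` primitive name, and the identity `impl`/abstract-eval rules are assumptions added only so the example runs on its own (in my case the primitive is removed by a later transformation instead). The `try`/`except` import is also an assumption, to cover JAX versions where `Primitive` lives in `jax.extend.core` rather than `jax.core`.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

try:  # newer JAX releases expose Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive


# Hypothetical pytree standing in for a user-defined config object.
@tree_util.register_pytree_node_class
class Config:
    def __init__(self, scale):
        self.scale = scale

    def tree_flatten(self):
        return (self.scale,), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Hypothetical "indicator site" primitive; it only ever sees flat leaves.
trace_p = Primitive("trace_site")
trace_p.multiple_results = True
# Identity rules (an assumption for this sketch): outputs mirror the inputs.
trace_p.def_impl(lambda *leaves, tree_form: list(leaves))
trace_p.def_abstract_eval(lambda *avals, tree_form: list(avals))


def trace_site(*args):
    # Flatten before bind so the primitive receives arrays/tracers only,
    # and carry the tree structure along as a keyword parameter.
    leaves, tree_form = tree_util.tree_flatten(args)
    out_leaves = trace_p.bind(*leaves, tree_form=tree_form)
    # Unflatten after bind to hand pytrees back to the caller.
    return tree_util.tree_unflatten(tree_form, out_leaves)


cfg = Config(jnp.float32(2.0))
x = jnp.ones(3)

out_cfg, out_x = trace_site(cfg, x)  # eager call goes through the impl rule
jaxpr = jax.make_jaxpr(lambda c, v: trace_site(c, v))(cfg, x)
print(jaxpr)
```

The later transformation can then read the structure back from the equation's parameters (`eqn.params["tree_form"]`) when it rewrites the site, so the pytree never has to cross `bind` itself.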