How to write a partial vmap with a user PyTree #10210
jecampagne asked this question in Q&A
-
Just treat your pytree as a container, and pass an instance of it as the `in_axes` entry for that argument:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Params:
    def __init__(self, a, b):
        self._a = a
        self._b = b

    def __repr__(self):
        return f"Params(a={self._a}, b={self._b})"

    @property
    def a(self):
        return self._a

    @property
    def b(self):
        return self._b

    def tree_flatten(self):
        children = (self._a, self._b)
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def g(p, x):
    return (p.a)**2 + (p.b)*x


my_params = Params(3., 1.)
# in_axes mirrors the argument structure: a Params whose fields hold the axes
# (None = broadcast `a`, 0 = map `b` along axis 0), and None for x.
print(jax.vmap(g, in_axes=(Params(None, 0), None))(Params(3., jnp.array([1., 2., 3.])), 10))
# [19. 29. 39.]
```
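The answer works because `in_axes` only has to be a *prefix* of the argument's pytree: each position can be either a full per-leaf spec or a single axis applied to everything below it. A minimal sketch of both forms, reusing a stripped-down copy of the `Params` class (plain attributes instead of properties, an assumption made only for brevity):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Params:
    """Minimal two-leaf pytree, a simplified copy of the class in the answer."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def g(p, x):
    return p.a ** 2 + p.b * x


# Per-leaf spec: broadcast `a` (None), map `b` along axis 0; x is broadcast.
partial = jax.vmap(g, in_axes=(Params(None, 0), None))(
    Params(3.0, jnp.array([1.0, 2.0, 3.0])), 10.0
)

# Prefix spec: a single 0 for the whole Params maps *every* leaf,
# so `a` must carry a batch dimension as well.
full = jax.vmap(g, in_axes=(0, None))(
    Params(jnp.array([3.0, 3.0, 3.0]), jnp.array([1.0, 2.0, 3.0])), 10.0
)

print(partial)
print(full)
```

Both calls compute the same values here; the difference is that with the single `0`, `a` has to be supplied as an array of three copies of `3.0` instead of a scalar.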
-
Hello, let me write some snippets to introduce my problem with applying `vmap` to a user PyTree. So, first let me practice with a dictionary:
this gives respectively
So far so good. Now let me define my PyTree
and play the same game with it as with the dictionary:
gives
OK, but how do I write the equivalent of `in_axes=({"a": None, "b": 0}, None)` for my PyTree? Thanks.
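The asker's original snippets did not survive, so the following is only an illustrative sketch of the dictionary version being referred to; the function `g` and the values are assumptions chosen to match the answer's example, not the asker's code. The dict of axes mirrors the structure of the dict argument:

```python
import jax
import jax.numpy as jnp


def g(p, x):
    # p is a plain dict with keys "a" and "b" (hypothetical example values below)
    return p["a"] ** 2 + p["b"] * x


params = {"a": 3.0, "b": jnp.array([1.0, 2.0, 3.0])}

# "a" is broadcast (None), "b" is mapped along axis 0, and x is broadcast (None).
out = jax.vmap(g, in_axes=({"a": None, "b": 0}, None))(params, 10.0)
print(out)
```

The custom-class answer follows the same pattern, with `Params(None, 0)` taking the place of `{"a": None, "b": 0}`.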