Clean way to map randomness over a pytree #7212
greeneggsandyaml asked this question in Q&A
Replies: 1 comment
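I imagine the best approach would be to split the key into one subkey per leaf, arrange those subkeys into a pytree with the same structure as `d`, and then map over both trees together:

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
d = {'a': {'aa': jnp.zeros(4), 'ab': jnp.zeros(4)},
     'b': {'aa': jnp.zeros(5), 'ab': jnp.zeros(5)}}

# One subkey per leaf, plus a fresh `key` to carry forward.
key, *subkeys = random.split(key, 1 + len(jax.tree_util.tree_leaves(d)))
# Arrange the subkeys in the same tree structure as `d`.
subkeys = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(d), subkeys)
# Map over `d` and `subkeys` in parallel, adding N(0, 1) noise to each leaf.
result = jax.tree_util.tree_map(lambda x, k: x + random.normal(k, x.shape), d, subkeys)
print(result)
# Output as originally posted. Newer JAX versions print `Array` instead of
# `DeviceArray`, and the exact values may differ between JAX versions.
# {'a': {'aa': DeviceArray([-0.9407592 , -1.4780389 , -1.0911523 , -0.23124945], dtype=float32),
#        'ab': DeviceArray([ 0.6453762 , -1.1911646 ,  0.82499367, -0.35288557], dtype=float32)},
#  'b': {'aa': DeviceArray([ 0.8250121 ,  0.8801492 , -1.1567167 , -0.03485736,  0.41830325], dtype=float32),
#        'ab': DeviceArray([-2.46225   ,  1.622537  ,  1.429637  , -0.42747816,  1.1972864 ], dtype=float32)}}
```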
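If this pattern comes up often, the same idea can be wrapped in a small helper. The sketch below is one possible way to do that, not a JAX API: `tree_add_normal_noise` and its `scale` parameter are hypothetical names chosen for illustration. It flattens the tree once and reuses the resulting treedef rather than calling `tree_structure` separately, and it draws noise in each leaf's own dtype.

```python
import jax
import jax.numpy as jnp
from jax import random


def tree_add_normal_noise(key, tree, scale=1.0):
    """Hypothetical helper: add independent N(0, scale**2) noise to every leaf."""
    # Flatten once; `treedef` records the structure needed to rebuild the tree.
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # Exactly one independent subkey per leaf.
    keys = random.split(key, len(leaves))
    noisy_leaves = [
        leaf + scale * random.normal(k, leaf.shape, leaf.dtype)
        for leaf, k in zip(leaves, keys)
    ]
    # Rebuild a pytree with the same structure as the input.
    return jax.tree_util.tree_unflatten(treedef, noisy_leaves)


params = {'a': {'aa': jnp.zeros(4), 'ab': jnp.zeros(4)},
          'b': {'aa': jnp.zeros(5), 'ab': jnp.zeros(5)}}
noisy_params = tree_add_normal_noise(random.PRNGKey(0), params)
```

Because this version splits the key into exactly one subkey per leaf (instead of `1 + n` subkeys with the first kept as the carried-forward key), its noise values will not match the output printed above, even with the same seed.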
Hello authors,
Thank you for this wonderful library. I have a quick question: what is the cleanest way to map a random function over the leaves of a pytree? For example, let's say I want to add N(0, 1) Gaussian noise to each parameter in a dict of parameters.
Concretely, suppose we have the following:
I imagine that there is a way of doing something like:
or something similarly clean. Alternatively, if there's a better way of doing this, I'd love to know.
I figured out that it's possible to do
and this works, but there must be a better way. I feel like I'm missing something obvious here, so I figured that I should ask.
Thanks for your help!
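For comparison, here is a different technique worth knowing about, sketched under the assumption that drawing all the noise from a single key is acceptable. `jax.flatten_util.ravel_pytree` concatenates every leaf into one 1-D array and returns a function that restores the original structure, so only one `random.normal` call is needed. The `params` dict is an illustrative stand-in for the parameter dict described in the question.

```python
import jax
import jax.numpy as jnp
from jax import random
from jax.flatten_util import ravel_pytree

params = {'a': {'aa': jnp.zeros(4), 'ab': jnp.zeros(4)},
          'b': {'aa': jnp.zeros(5), 'ab': jnp.zeros(5)}}

# Concatenate all leaves into a single 1-D vector; `unravel` maps a vector
# of the same length back to the original pytree structure.
flat, unravel = ravel_pytree(params)

key = random.PRNGKey(0)
# A single N(0, 1) draw covering all 18 parameters at once.
noise = random.normal(key, flat.shape, flat.dtype)
noisy_params = unravel(flat + noise)
```

The trade-off is that this copies every parameter into one contiguous buffer and back, which may matter for large models; the per-leaf approach avoids that copy.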