Overview

I have a very complex pytree of model parameters, which I will call model A: a massive nested dict of tensors of various shapes. I want to copy the pytree structure of A but fill it with randomly sampled tensors that have the same shapes as the tensors in A. That way I can do a tree_multimap with A and this copy.

My Issue

My idea is to traverse the tree and generate a random sample for each tensor leaf, but I am having trouble propagating the PRNGKey during the traversal. I could map a sampling function over A with a single key, but then the random tensors would be very similar to each other (even though some have different shapes), because I cannot split the key during the traversal. I thought about using jax.lax.scan, since it lets me carry a state, but when I pass A as the `xs` parameter (referring to the documentation) it complains that the leading axes are not the same.

Question

Do you know how I can cleanly generate copy_of_A and split the key after each sampling? If this is part of a larger function that I am calling jit on, how can I do this correctly, if possible?
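A minimal sketch of the single-key pitfall, assuming a toy two-layer parameter dict (the names `dense1`, `dense2`, `w`, `b` and all shapes are hypothetical). Because JAX random functions are deterministic in their key, mapping `jax.random.normal` over the tree with one key makes every pair of leaves with the same shape and dtype identical, not just similar. It also explains the `lax.scan` error: scan slices every leaf of `xs` along its leading axis, so all leaves must share that axis length, which parameter tensors of different shapes generally do not.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for "model A": a nested dict of arrays.
params = {
    "dense1": {"w": jnp.zeros((3, 4)), "b": jnp.zeros((4,))},
    "dense2": {"w": jnp.zeros((3, 4)), "b": jnp.zeros((4,))},
}

key = jax.random.PRNGKey(0)

# Naive approach: every leaf is sampled from the *same* key.
same_key_copy = jax.tree_util.tree_map(
    lambda x: jax.random.normal(key, x.shape, x.dtype), params
)

# Leaves with identical shape and dtype come out bit-for-bit identical.
print(jnp.array_equal(same_key_copy["dense1"]["w"], same_key_copy["dense2"]["w"]))
```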
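I have been using something like this:

```python
import jax


def random_split_like_tree(rng_key, target=None, treedef=None):
    # Derive the tree structure from `target` unless a treedef is given.
    if treedef is None:
        treedef = jax.tree_util.tree_structure(target)
    # One independent subkey per leaf.
    keys = jax.random.split(rng_key, treedef.num_leaves)
    # Arrange the subkeys in the same nested structure as `target`.
    return jax.tree_util.tree_unflatten(treedef, keys)


def tree_random_normal_like(rng_key, target):
    keys_tree = random_split_like_tree(rng_key, target)
    # jax.tree_multimap was removed from recent JAX releases;
    # jax.tree_util.tree_map accepts several trees with the same structure.
    return jax.tree_util.tree_map(
        lambda l, k: jax.random.normal(k, l.shape, l.dtype),
        target,
        keys_tree,
    )
```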
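A usage sketch, assuming a hypothetical `perturb` function that stands in for the larger jitted function from the question; the parameter names, shapes, and the 0.01 noise scale are made up for illustration. The number of subkeys comes from the tree structure (`num_leaves` is a plain Python int), so the helpers trace fine inside `jax.jit`, and the noise tree can then be combined leaf-by-leaf with the original parameters via `tree_map`, which is what the question wanted `tree_multimap` for.

```python
import jax
import jax.numpy as jnp


def random_split_like_tree(rng_key, target=None, treedef=None):
    if treedef is None:
        treedef = jax.tree_util.tree_structure(target)
    keys = jax.random.split(rng_key, treedef.num_leaves)
    return jax.tree_util.tree_unflatten(treedef, keys)


def tree_random_normal_like(rng_key, target):
    keys_tree = random_split_like_tree(rng_key, target)
    return jax.tree_util.tree_map(
        lambda l, k: jax.random.normal(k, l.shape, l.dtype), target, keys_tree
    )


# Hypothetical "larger function": add small Gaussian noise to every parameter.
@jax.jit
def perturb(params, rng_key):
    noise = tree_random_normal_like(rng_key, params)
    # Combine the original tree and the noise tree leaf by leaf.
    return jax.tree_util.tree_map(lambda p, n: p + 0.01 * n, params, noise)


params = {
    "encoder": {"w": jnp.ones((8, 16)), "b": jnp.zeros((16,))},
    "decoder": {"w": jnp.ones((16, 8)), "b": jnp.zeros((8,))},
}

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)  # keep `key` around for later draws
noisy = perturb(params, subkey)
```

Passing a fresh subkey on each call, rather than reusing one key, is what keeps successive draws independent.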