There are several pytree operations that, IMO, would be very useful if present in `jax.tree_util`:

Say I have two pytrees with different structures: `tree1 = {"a": 2, "b": {"c": 1}}` and `tree2 = {"a": 3, "d": 5}`. I would like to do a `tree_map` on these trees with some function `f` so that the result is `{"a": f(2, 3), "b": {"c": f(1, None)}, "d": f(None, 5)}`, i.e. `None` is passed to `f` for leaves that are missing in one of the trees. This would, for instance, simplify model surgery by making it possible to combine parameters from multiple models.

Are there any existing libraries built on top of `jax.tree_util` that can do something similar? Or is there perhaps a reasonably easy way to build something like this from existing primitives (I can't think of any)?