Updating slices in a pytree with slices from another pytree of the same shape #14686
StoneT2000 asked this question in Q&A
I think you want this:

jux_action: JuxAction = jax.tree_map(lambda x, y: x.at[:, 1].set(y[:, 1]), jux_action_0, jux_action_1)

or, if you want to use .at[...].get() for the read as well:

jux_action: JuxAction = jax.tree_map(lambda x, y: x.at[:, 1].set(y.at[:, 1].get()), jux_action_0, jux_action_1)
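A minimal sketch checking the suggestion above on a toy pytree. The dict keys and shapes are hypothetical stand-ins, not the real JuxAction fields. It calls jax.tree_util.tree_map, of which jax.tree_map is an alias (newer JAX releases have deprecated the alias). Both variants should give the same result: team 1 taken from the second tree, team 0 kept from the first.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy pytrees: every leaf shares the leading (batch, team) dims.
a0 = {"x": jnp.zeros((3, 2, 4)), "y": jnp.zeros((3, 2))}
a1 = {"x": jnp.ones((3, 2, 4)), "y": jnp.ones((3, 2))}

def take_team_1(x, y):
    # Functional update: returns a NEW array equal to x except at index 1 of axis 1.
    return x.at[:, 1].set(y[:, 1])

# Variant 1: plain indexing to read the slice from y.
merged = jax.tree_util.tree_map(take_team_1, a0, a1)

# Variant 2: .at[...].get() to read the slice; same values as variant 1.
merged_get = jax.tree_util.tree_map(
    lambda x, y: x.at[:, 1].set(y.at[:, 1].get()), a0, a1
)
```

Note that a0 itself is left untouched; only merged and merged_get hold the updated values.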
I currently have a PyTree with the following shape:
Notably, every leaf shares the leading dimensions (10, 2) = (batch_size, team_id) (this is for a game environment with two players).
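To make the layout concrete, here is a hypothetical stand-in for such a pytree. ToyAction, its field names, and the trailing dimensions are invented for illustration; only the shared (10, 2) = (batch_size, team_id) prefix comes from the question. A NamedTuple is treated as a pytree by JAX automatically, so tree utilities work on it directly.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

BATCH_SIZE, NUM_TEAMS = 10, 2

# Hypothetical stand-in for JuxAction; the real class has different fields.
class ToyAction(NamedTuple):
    per_team_flag: jax.Array   # shape (batch_size, team_id, 3)
    per_unit_queue: jax.Array  # shape (batch_size, team_id, 4, 2, 5)
    bid: jax.Array             # shape (batch_size, team_id)

def make_toy_action(fill: int) -> ToyAction:
    """Build a ToyAction whose leaves are all filled with `fill`."""
    return ToyAction(
        per_team_flag=jnp.full((BATCH_SIZE, NUM_TEAMS, 3), fill, dtype=jnp.int32),
        per_unit_queue=jnp.full((BATCH_SIZE, NUM_TEAMS, 4, 2, 5), fill, dtype=jnp.int32),
        bid=jnp.full((BATCH_SIZE, NUM_TEAMS), fill, dtype=jnp.int32),
    )

def leading_dims(tree) -> set:
    """Collect the (batch_size, team_id) prefix of every leaf's shape."""
    return {leaf.shape[:2] for leaf in jax.tree_util.tree_leaves(tree)}

action = make_toy_action(0)
print(leading_dims(action))  # {(10, 2)}
```

Because the trailing dimensions differ per leaf, an update written as x.at[:, team] only has to agree on the first two axes.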
Suppose I have two of these, jux_action_0 and jux_action_1. However, the following doesn't work:

I was hoping this would set the 2nd element along the 2nd dimension (team_id = 1) to what is stored in jux_action_1 and leave everything else as it was in jux_action_0. However, I get the following error:

This seems to be because the return value of the lambda passed to tree_map is the output of .set. Is there another way to set the value of a slice from another slice and return it in tree_map?
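As a sketch of the underlying mechanics: returning the output of .set from the lambda is, by itself, valid. x.at[...].set(...) never mutates x; it returns a new array, and tree_map collects those returned arrays into a new pytree with the same structure as its first argument. The hypothetical merge_team helper below also takes the team index as an argument so that it can run under jax.jit; the dict leaves are toy data, not JuxAction fields.

```python
import jax
import jax.numpy as jnp

@jax.jit
def merge_team(base, other, team):
    """Copy team `team` (axis 1) from `other` into `base`, leaf by leaf.

    Each lambda call returns the new array produced by .at[...].set();
    tree_map assembles those return values into a pytree shaped like `base`.
    Under jit, `team` is a traced integer, which .at indexing accepts.
    """
    return jax.tree_util.tree_map(
        lambda x, y: x.at[:, team].set(y[:, team]), base, other
    )

# Hypothetical toy pytrees with the (batch_size, team_id) = (10, 2) prefix.
base = {"bid": jnp.zeros((10, 2), jnp.int32), "queue": jnp.zeros((10, 2, 5), jnp.int32)}
other = {"bid": jnp.full((10, 2), 7, jnp.int32), "queue": jnp.full((10, 2, 5), 7, jnp.int32)}

out = merge_team(base, other, 1)  # team 1 from `other`, team 0 from `base`
```

Passing the team index as a traced argument means the same compiled function can merge either team without retracing.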