Trying to implement some TensorFlow functions in JAX #10467
Answered by YouJiacheng
bigyankarki asked this question in Q&A
I have this function in TensorFlow:
I have refactored the code to JAX as follows:
However, I keep getting an error:
The inputs are:
Output so far:
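The question's TensorFlow and JAX code did not survive on this page. For orientation only, here is a minimal sketch of an unbatched `gather_nd` in JAX. It assumes `tf.gather_nd` semantics with `batch_dims=0`, where the last axis of `indices` holds an index into the leading axes of `params`. The name `unbatched_gather_nd` is taken from the answer below; this body is a hypothetical stand-in, not the poster's implementation.

```python
import jax.numpy as jnp

def unbatched_gather_nd(params, indices):
    # Hypothetical stand-in for gather_nd with batch_dims=0.
    # indices[..., k] indexes axis k of params. Moving that last axis to the
    # front and splitting it into a tuple lets advanced indexing do the gather.
    # Result shape: indices.shape[:-1] + params.shape[indices.shape[-1]:]
    return params[tuple(jnp.moveaxis(indices, -1, 0))]

params = jnp.arange(12).reshape(3, 4)   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
indices = jnp.array([[0, 1], [2, 3]])   # two (row, col) pairs
print(unbatched_gather_nd(params, indices).tolist())  # [1, 11]
```

Because the index tuple has a static length (the size of the last axis of `indices`), this also works when traced under `jit` or `vmap`.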
Answered by YouJiacheng on Apr 28, 2022
Replies: 2 comments
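Try:

```python
from jax import vmap
# Map unbatched_gather_nd over axis 0 of indices, sharing params across the batch.
def batched_gather_nd(params, indices):
    return vmap(unbatched_gather_nd, (None, 0), 0)(params, indices)
```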
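To see what the answer's `vmap` call does: `in_axes=(None, 0)` passes the same `params` to every batch element while mapping over axis 0 of `indices`, and `out_axes=0` stacks the per-example results along a new leading axis. The self-contained sketch below defines its own `unbatched_gather_nd`. That function is a hypothetical stand-in, assuming `tf.gather_nd` semantics with `batch_dims=0`, since the poster's version is not shown here. The shapes are illustrative only.

```python
import jax.numpy as jnp
from jax import vmap

def unbatched_gather_nd(params, indices):
    # Hypothetical stand-in: gather_nd with batch_dims=0 via advanced indexing.
    return params[tuple(jnp.moveaxis(indices, -1, 0))]

def batched_gather_nd(params, indices):
    # params is not mapped (None); indices is mapped over axis 0;
    # per-example outputs are stacked along axis 0 (out_axes=0).
    return vmap(unbatched_gather_nd, (None, 0), 0)(params, indices)

params = jnp.arange(12).reshape(3, 4)        # shape (3, 4)
indices = jnp.array([[[0, 1], [2, 3]],       # batch 0: two (row, col) pairs
                     [[1, 0], [0, 0]]])      # batch 1: two (row, col) pairs
out = batched_gather_nd(params, indices)
print(out.shape)     # (2, 2)
print(out.tolist())  # [[1, 11], [4, 0]]
```

The result matches calling `unbatched_gather_nd(params, indices[b])` separately for each `b` and stacking the outputs. `vmap` just produces that result without a Python loop.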
Answer selected by bigyankarki
This seems to work for me: