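Try wrapping your existing unbatched function with `vmap`, passing `params` through unchanged and mapping over the leading (batch) axis of `indices`: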

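from jax import vmap

# in_axes=(None, 0): `params` is shared across the batch, `indices` is mapped over axis 0.
# out_axes=0: the per-example results are stacked along a new leading axis.
def batched_gather_nd(params, indices):
    return vmap(unbatched_gather_nd, (None, 0), 0)(params, indices)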
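For a self-contained check, here is a minimal sketch. The `unbatched_gather_nd` below is a hypothetical stand-in, not necessarily the one from the question. It assumes TensorFlow-style `gather_nd` semantics, where the last axis of `indices` (size `k`) indexes the first `k` dimensions of `params`. The `vmap` wrapper is the one suggested above.

```python
import jax.numpy as jnp
from jax import vmap

def unbatched_gather_nd(params, indices):
    # Hypothetical helper (assumption): TF-style gather_nd for a single example.
    # indices has shape (n, k); row j selects params[indices[j, 0], ..., indices[j, k-1]].
    # Moving the last axis to the front turns it into k index arrays for advanced indexing.
    return params[tuple(jnp.moveaxis(indices, -1, 0))]

def batched_gather_nd(params, indices):
    # params is not mapped (None); indices is mapped over its leading batch axis (0).
    return vmap(unbatched_gather_nd, (None, 0), 0)(params, indices)

params = jnp.arange(12).reshape(3, 4)
indices = jnp.array([[[0, 1], [2, 3]],
                     [[1, 0], [0, 0]]])  # shape (batch=2, n=2, k=2)
out = batched_gather_nd(params, indices)
# Each batch row gathers params[0, 1], params[2, 3] and params[1, 0], params[0, 0].
print(out.tolist())  # [[1, 11], [4, 0]]
```

Because `params` uses `in_axes=None`, it is broadcast to every batch element rather than copied per example. This usually makes `vmap` preferable to a Python loop over the batch.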

Answer selected by bigyankarki
Category
Q&A
2 participants