Searching for a "mapreduce" primitive in JAX led me to this thread. It looks like the answer is in fact to use vmap and then apply the reduction to its output, though apparently scan can work too under some circumstances. In either case, jit is crucial so that XLA can optimize the computation.
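For anyone landing here later, this is a minimal sketch of both patterns, not a definitive implementation. The `square` function and the names `mapreduce_vmap` and `mapreduce_scan` are made up for illustration, and I'm assuming the reduction is a plain sum over the leading axis.

```python
import jax
import jax.numpy as jnp


def square(x):
    # Hypothetical per-element "map" function; swap in your own.
    return x * x


@jax.jit
def mapreduce_vmap(xs):
    # Map over the leading axis with vmap, then reduce the stacked results.
    return jnp.sum(jax.vmap(square)(xs), axis=0)


@jax.jit
def mapreduce_scan(xs):
    # Sequential alternative: carry a running total through lax.scan,
    # adding one mapped element per step.
    def step(acc, x):
        return acc + square(x), None

    init = jnp.zeros(xs.shape[1:], xs.dtype)
    total, _ = jax.lax.scan(step, init, xs)
    return total


xs = jnp.arange(5.0)
result_vmap = mapreduce_vmap(xs)
result_scan = mapreduce_scan(xs)
```

Both calls should give the same sum here. The vmap version expresses the map as one batched operation, while the scan version steps through the elements in order, which may be preferable when the stacked intermediate results would be too large to hold at once.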

@aeftimia
@soraros
Answer selected by aeftimia