pmap has a similar API to vmap, so I suspect what you want is this:

pmap_mm = pmap(vmap_mv, in_axes=(None, 1), out_axes=1)
pmap_mm(aaa, aaa)

That is, don't use pmap in addition to the outer vmap; use pmap in place of it.
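Here is a minimal, self-contained sketch of the swap. The question's original `vmap_mv` and `aaa` aren't shown here, so `vv`, `vmap_mv`, `A`, and `B` below are hypothetical stand-ins: a dot product vectorized into a matrix-vector product, then mapped over the columns of a second matrix. Because pmap places each mapped slice on its own device, the mapped axis can't be longer than the number of available devices, so the sketch sizes `B`'s column count with `jax.local_device_count()` (on a single-CPU machine that is just 1).

```python
import jax
import jax.numpy as jnp
from jax import vmap, pmap

# Hypothetical stand-in: vector-vector dot product.
def vv(a, b):
    return jnp.dot(a, b)

# Matrix-vector product: map vv over the rows of M, broadcasting v.
vmap_mv = vmap(vv, in_axes=(0, None), out_axes=0)

# All-vmap matrix-matrix product: map vmap_mv over the columns of the
# second argument and stack the results as output columns.
vmap_mm = vmap(vmap_mv, in_axes=(None, 1), out_axes=1)

# Same signature with pmap in place of the outer vmap: each column of the
# second argument is processed on a separate device.
pmap_mm = pmap(vmap_mv, in_axes=(None, 1), out_axes=1)

# pmap requires the mapped axis size to fit the available devices,
# so give B exactly one column per local device.
n_dev = jax.local_device_count()
A = jnp.arange(12.0).reshape(4, 3)
B = jnp.arange(3.0 * n_dev).reshape(3, n_dev)

print(jnp.allclose(pmap_mm(A, B), A @ B))
print(jnp.allclose(vmap_mm(A, B), A @ B))
```

The inner vmap still vectorizes within each device; only the outermost mapping is distributed. To use more columns than devices, you would need to reshape the mapped axis into a (devices, per-device) pair and vmap over the per-device part.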

Replies: 1 comment, 5 replies (@oconnoob, @jakevdp)
Answer selected by oconnoob
Category: Q&A
2 participants