

It sounds like what you want is similar to the semantics of jnp.vectorize. Using the variables you defined, it might look something like this:

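import jax.numpy as jnp
from functools import partial

# Core signature: x has shape (m,), c is a scalar, output has shape (n,).
# Any extra leading dimensions of x and c are broadcast and mapped over.
@partial(jnp.vectorize, signature='(m),()->(n)')
def fn(x, c):
  return A @ x + c  # A has shape (n, m), defined earlier

out = fn(x, c)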
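The snippet above relies on A, x and c from your question. Here is a self-contained version you can run directly; the shapes (A of shape (3, 4), a batch of 5 vectors) are made up purely for illustration.

```python
import jax.numpy as jnp
from functools import partial

# Hypothetical matrix for illustration: A maps R^4 -> R^3, so m = 4 and n = 3.
A = jnp.arange(12.0).reshape(3, 4)

@partial(jnp.vectorize, signature='(m),()->(n)')
def fn(x, c):
    # Inside fn, x has shape (4,) and c is a scalar.
    return A @ x + c

x = jnp.ones((5, 4))   # batch of 5 vectors, each of length m = 4
c = jnp.arange(5.0)    # one scalar offset per batch element

out = fn(x, c)
print(out.shape)  # (5, 3)
```

The batch dimensions of x (here (5,)) and c (here (5,)) are broadcast against each other, so passing a plain scalar such as fn(x, 1.0) should also work.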

jnp.vectorize is implemented via nested calls to vmap, as in your approach. But if that's not the API you want, I think your wrapper function looks like a great solution.
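To make the connection to vmap concrete, here is a minimal sketch (not your wrapper, which isn't reproduced here) that maps the same core function over two batch dimensions with nested jax.vmap calls and checks it against jnp.vectorize. The matrix A and the batch shape (2, 5) are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical matrix for illustration: A maps R^4 -> R^3.
A = jnp.arange(12.0).reshape(3, 4)

def core(x, c):
    # Unbatched computation: x has shape (4,), c is a scalar.
    return A @ x + c

# Each vmap maps over one leading axis of both inputs; two vmaps = two batch dims.
batched = jax.vmap(jax.vmap(core))

x = jnp.ones((2, 5, 4))
c = jnp.arange(10.0).reshape(2, 5)

out_vmap = batched(x, c)
out_vec = jnp.vectorize(core, signature='(m),()->(n)')(x, c)
print(out_vmap.shape)                   # (2, 5, 3)
print(jnp.allclose(out_vmap, out_vec))  # True
```

One practical difference: nested vmap as written requires x and c to have matching batch axes, whereas jnp.vectorize first broadcasts the batch dimensions, so it also accepts, for example, c of shape (5,) with x of shape (2, 5, 4).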

Answer selected by andrewwarrington
Category: Ideas