
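@davisyoshida's solution is correct, but I suggest using a closure (or `functools.partial`) to simplify the code. Because `input2` and `input3` are captured by the inner function, only `inputs` passes through the three `jax.vmap` layers: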

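```python
import jax

def foo(inputs, input2, input3):
    # Three stacked vmaps map over the three leading axes of `inputs`.
    @jax.vmap
    @jax.vmap
    @jax.vmap
    def f(inputs_unmapped):
        # input2 and input3 come from the enclosing scope, so vmap never maps them.
        return MyFunction(inputs_unmapped, input2, input3)
    return f(inputs)
```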
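A minimal runnable sketch of the same idea, assuming a hypothetical `my_function(x, scale, offset)` as a stand-in for `MyFunction` and a 3-D `inputs` array. It compares the closure above, a `functools.partial` variant, and, for reference, the explicit `in_axes=(0, None, None)` form of `jax.vmap`, where `None` tells vmap not to map that argument. All three should give the same result.

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical stand-in for MyFunction: one unmapped element plus two
# extra arguments that must NOT be vmapped over.
def my_function(x, scale, offset):
    return x * scale + offset


def foo_closure(inputs, input2, input3):
    # Closure: input2/input3 are captured, vmap only sees inputs_unmapped.
    @jax.vmap
    @jax.vmap
    @jax.vmap
    def f(inputs_unmapped):
        return my_function(inputs_unmapped, input2, input3)
    return f(inputs)


def foo_partial(inputs, input2, input3):
    # Bind the extra arguments by keyword; binding them positionally would
    # fill `x` instead. vmap only sees the remaining positional argument.
    f = functools.partial(my_function, scale=input2, offset=input3)
    for _ in range(3):  # one vmap per leading axis of `inputs`
        f = jax.vmap(f)
    return f(inputs)


def foo_in_axes(inputs, input2, input3):
    # No closure: map argument 0 and pass the other two through unchanged.
    f = my_function
    for _ in range(3):
        f = jax.vmap(f, in_axes=(0, None, None))
    return f(inputs, input2, input3)


inputs = jnp.arange(24.0).reshape(2, 3, 4)
scale = jnp.array(2.0)
offset = jnp.array(1.0)

out_closure = foo_closure(inputs, scale, offset)
out_partial = foo_partial(inputs, scale, offset)
out_in_axes = foo_in_axes(inputs, scale, offset)
print(out_closure.shape)
```

The closure and `partial` versions keep the `jax.vmap` calls free of `in_axes` bookkeeping; the `in_axes` version avoids defining an inner function but has to repeat `(0, None, None)` at every level.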

Answer selected by dominicpasquali
Category: Q&A | Labels: none yet | 3 participants