@joeryjoery Hmm, I don't see why a nested tree_map is needed here, assuming all arguments are packed into a single pytree.

import jax
import jax.numpy as jnp

def nested_vmap(fun, n: int):
    for _ in range(n):
        fun = jax.vmap(fun)
    return fun

def gvp(inner_fun, outer_fun, p_in, t_in):
    # p_in: pytree_0 (point at which inner_fun is linearized)
    # t_in: pytree_0 (tangent, same structure as p_in)
    # inner_fun: pytree_0 -> pytree_1
    # outer_fun: pytree_1 -> pytree_2
    p_out, f_l = jax.linearize(inner_fun, p_in) # (pytree_1), (pytree_0 -> pytree_1)
    f_lt_tuple = jax.linear_transpose(f_l, p_in) # pytree_1 -> pytree_0
    f_lt = lambda x: f_lt_tuple(x)[0] # the returned tuple has one entry per primal; there is only one here
    Jt = f_l(t_in) # pytree_1
    d_outer, HJt = jax.j…
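
The snippet above is cut off, and the following does not try to recover the missing lines. It is an independent minimal sketch of one common way to form a Gauss-Newton-style product `J^T H J t` with `jax.linearize` and `jax.linear_transpose`, without any nested `tree_map`. The names `gnvp_sketch`, `A`, `p`, and `t` are hypothetical, and the sketch assumes `outer_fun` returns a scalar so that `jax.grad` applies.

```python
import jax
import jax.numpy as jnp


def gnvp_sketch(inner_fun, outer_fun, p_in, t_in):
    """Hypothetical helper: J^T H J t, with J the Jacobian of inner_fun at p_in
    and H the Hessian of the (assumed scalar-valued) outer_fun at inner_fun(p_in)."""
    # p_out = inner_fun(p_in); jvp_fn maps a tangent t to J t.
    p_out, jvp_fn = jax.linearize(inner_fun, p_in)
    # Transposing the linear map gives v -> (J^T v,), a tuple with one entry.
    vjp_fn = jax.linear_transpose(jvp_fn, p_in)
    Jt = jvp_fn(t_in)
    # Forward-over-reverse Hessian-vector product of outer_fun at p_out.
    _, HJt = jax.jvp(jax.grad(outer_fun), (p_out,), (Jt,))
    (JtHJt,) = vjp_fn(HJt)
    return JtHJt


# Toy problem (illustrative values only): parameters live in a dict pytree.
A = jnp.array([[1.0, 2.0, 0.5], [0.0, -1.0, 3.0]])
inner_fun = lambda p: jnp.tanh(A @ p["w"])
outer_fun = lambda y: jnp.sum(y ** 4)
p = {"w": jnp.array([0.1, -0.2, 0.3])}
t = {"w": jnp.array([1.0, 0.5, -1.0])}

result = gnvp_sketch(inner_fun, outer_fun, p, t)

# Dense reference built from explicit Jacobian and Hessian matrices.
J = jax.jacobian(lambda w: inner_fun({"w": w}))(p["w"])
H = jax.hessian(outer_fun)(inner_fun(p))
expected = J.T @ H @ J @ t["w"]
```

The result has the same pytree structure as `p_in`, because `jax.linear_transpose` handles the pytree flattening itself; the dense reference is only there to check the product on a small example.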

Answer selected by joeryjoery