The batching rule for a primitive takes a tuple of arguments and a tuple of batch dimensions (one per argument), evaluates the batched version of the primitive, and returns the batched output together with its batch dimension. Since your primitive is implemented with ordinary JAX operations, you can implement the batching rule via a call to `vmap`. For example:

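```python
import jax.numpy as jnp
from jax import core  # newer JAX releases expose this as jax.extend.core.Primitive

# SIGMA, RHO, BETA: Lorenz-system constants, assumed to be defined elsewhere.
lorenz_jvp_p = core.Primitive("lorenz_jvp")

@lorenz_jvp_p.def_impl
def lorenz_jvp_impl(x, x_dot):
    # Jacobian of the Lorenz vector field at x, applied to the tangent x_dot.
    return jnp.array(
        [
            x_dot[1] * SIGMA - x_dot[0] * SIGMA,
            -x[2] * x_dot[0] - x[0] * x_dot[2] - x_dot[1] + x_dot[0] * RHO,
            x[1] * x_dot[0] + x[0] * x_dot[1] - x_dot[2] * BETA,
        ]
    )

def lorenz_jvp_batching_rule(batched_args, batch_dims):
    x, x_dot = batched_args
    ba…
```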
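A minimal, self-contained sketch of a complete `vmap`-based batching rule follows. It is not necessarily the code from the reply, and it rests on a few assumptions: the constants use the classic Lorenz values (σ = 10, ρ = 28, β = 8/3), which do not come from this thread; `lorenz_jvp` is a hypothetical user-facing wrapper; and the rule is registered through `batching.primitive_batchers`, the mechanism used in JAX's custom-primitives tutorial.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import batching

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Assumption: classic Lorenz parameters, not taken from the thread.
SIGMA, RHO, BETA = 10.0, 28.0, 8.0 / 3.0

lorenz_jvp_p = Primitive("lorenz_jvp")


@lorenz_jvp_p.def_impl
def lorenz_jvp_impl(x, x_dot):
    # Jacobian of the Lorenz vector field at x, applied to the tangent x_dot.
    return jnp.array(
        [
            x_dot[1] * SIGMA - x_dot[0] * SIGMA,
            -x[2] * x_dot[0] - x[0] * x_dot[2] - x_dot[1] + x_dot[0] * RHO,
            x[1] * x_dot[0] + x[0] * x_dot[1] - x_dot[2] * BETA,
        ]
    )


def lorenz_jvp_batching_rule(batched_args, batch_dims):
    # batch_dims has an int axis for each batched argument and
    # batching.not_mapped (None) for each unbatched one, which is what
    # vmap's in_axes accepts. vmap requires a tuple here, not a list.
    out = jax.vmap(lorenz_jvp_impl, in_axes=tuple(batch_dims), out_axes=0)(
        *batched_args
    )
    # Return the batched result plus the axis along which it is batched.
    return out, 0


batching.primitive_batchers[lorenz_jvp_p] = lorenz_jvp_batching_rule


def lorenz_jvp(x, x_dot):
    # Hypothetical wrapper that binds the primitive.
    return lorenz_jvp_p.bind(x, x_dot)


xs = jnp.arange(12.0).reshape(4, 3)
x_dots = jnp.ones((4, 3))
print(jax.vmap(lorenz_jvp)(xs, x_dots).shape)
print(jax.vmap(lorenz_jvp, in_axes=(None, 0))(xs[0], x_dots).shape)
```

Returning `0` as the output batch dimension matches `out_axes=0` in the inner `vmap`. Because `lorenz_jvp_impl` is built only from `jnp` operations, vmapping it inside the rule never re-binds `lorenz_jvp_p`; vmapping `lorenz_jvp_p.bind` there instead would recurse back into the same rule. The sketch covers eager evaluation and `vmap` only; using the primitive under `jit` would additionally need an abstract-evaluation rule and a lowering.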

Replies: 1 comment, 5 replies

@Hs293Go
@jakevdp
Answer selected by Hs293Go
@Hs293Go
@jakevdp
@Hs293Go
Category: Q&A
Labels: none yet
Participants: 2