EDIT (so that the bottom line is up front): the question is: why do `jax.jacfwd` and `jax.jacrev` require batching rules for the jvp/vjp primitives backing a custom primitive, and how should those batching rules be defined?
We have extensive docs on defining custom derivative rules and a helpful discussion #12730 (comment) on defining a derivative rule that is itself a primitive. All of these resources test their derivative rules with `jax.grad`, and that works for me too. The same could not be said for the Jacobian transformations: starting from the code in #12730 (comment) and replacing its last line with a `jax.jacfwd` or `jax.jacrev` call fails with `NotImplementedError: Batching rule for ... not implemented`.
This is quite astonishing, since batching rules and derivative rules seem, to the best of my knowledge, to be orthogonal: the docs never mention them together in the same context, and the primary source introducing batching rules, How JAX primitives work, demonstrates batching a function itself, not its jvp/vjp rule.
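For concreteness, the batching pattern as that document presents it applies `jax.vmap` to the function itself. A minimal sketch of that pattern (with a stand-in function, not the docs' own example):

```python
import jax
import jax.numpy as jnp

def f(x):
    # stand-in for any per-example function
    return jnp.sin(x) * 2.0

xs = jnp.arange(6.0).reshape(3, 2)
# Batching as the docs present it: vmap over the function itself,
# never over its jvp/vjp rules.
batched = jax.vmap(f)(xs)
```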
I'm including a brief demo of what is and isn't working, based on the code in #12730 (comment). EDIT 2: Per #19973, remove all uses of `jax.custom_jvp`.

```python
import jax
import jax.numpy as jnp
import numdifftools
import numpy as onp
from jax import core
from jax.interpreters import ad

# == SETUP PROBLEM: ADAPTED FROM 'Question about defining new JAX primitives #12730' ===

# Make a Primitive
lorenz_p = core.Primitive("lorenz")

# EDIT 2: Do NOT use @jax.custom_jvp
def lorenz(x):
    return lorenz_p.bind(x)

# hardcode Lorenz system parameters for simplicity
SIGMA, RHO, BETA = 28.0, 10.0, 8 / 3

@lorenz_p.def_impl
def lorenz_impl(x):
    """Lorenz system; the dynamical system is inherently vector-valued,
    rather than vector-valued through vectorization.
    """
    return onp.array(
        [SIGMA * (x[1] - x[0]), x[0] * (RHO - x[2]) - x[1], x[0] * x[1] - BETA * x[2]]
    )

# EDIT 2: Do NOT use @lorenz.defjvp
def lorenz_jvp(primals, tangents):
    (x,), (xdot,) = primals, tangents
    y = lorenz(x)
    y_dot = lorenz_jvp_p.bind(x, xdot)
    return y, y_dot

lorenz_jvp_p = core.Primitive("lorenz_jvp")

# EDIT 2: DO use
ad.primitive_jvps[lorenz_p] = lorenz_jvp

@lorenz_jvp_p.def_impl
def lorenz_jvp_impl(x, x_dot):
    return onp.array(
        [
            x_dot[1] * SIGMA - x_dot[0] * SIGMA,
            -x[2] * x_dot[0] - x[0] * x_dot[2] - x_dot[1] + x_dot[0] * RHO,
            x[1] * x_dot[0] + x[0] * x_dot[1] - x_dot[2] * BETA,
        ]
    )

@lorenz_jvp_p.def_abstract_eval
def lorenz_jvp_abstract_eval(_, x_dot_aval):
    y_dot_aval = core.ShapedArray(x_dot_aval.shape, x_dot_aval.dtype)
    return y_dot_aval

def lorenz_jvp_transpose(y_bar, x, x_dot_dummy):
    assert ad.is_undefined_primal(x_dot_dummy)  # just a dummy input
    x_bar = lorenz_vjp_p.bind(x, y_bar)  # y_bar aka y_grad
    return None, x_bar  # None for nonlinear primal input x

ad.primitive_transposes[lorenz_jvp_p] = lorenz_jvp_transpose

# Finally, let's write the vjp rule as a primitive.
lorenz_vjp_p = core.Primitive("lorenz_vjp")

@lorenz_vjp_p.def_impl
def lorenz_vjp_impl(x, v):
    return onp.array(
        [
            v[2] * x[1] + v[1] * (RHO - x[2]) - v[0] * SIGMA,
            v[2] * x[0] - v[1] + v[0] * SIGMA,
            -v[2] * BETA - v[1] * x[0],
        ]
    )

test_primal = jnp.array([1.0, 0.0, 0.0])
test_tangent = jnp.array([0.0, 1.0, 0.01])

# ============================= BASIC BEHAVIOR IS CORRECT ==============================
_, result = jax.jvp(lorenz, [test_primal], [test_tangent])  # Step 1: JVP
expected = numdifftools.Jacobian(lorenz)(test_primal) @ test_tangent
assert jnp.allclose(result, expected)

_, f_vjp = jax.vjp(lorenz, test_primal)  # Step 2: VJP
(result,) = f_vjp(test_tangent)
expected = test_tangent @ numdifftools.Jacobian(lorenz)(test_primal)
assert jnp.allclose(result, expected)

def scalar_val_test_fn(x):
    return jnp.hypot(*lorenz(x + onp.array([1.0, 2.0, 0.0]))[1::-1])

# Step 3: Autodiff a function that nests 'lorenz'
result = jax.grad(scalar_val_test_fn)(test_primal)
expected = numdifftools.Gradient(scalar_val_test_fn)(test_primal)
assert jnp.allclose(result, expected, rtol=1e-4)  # type: ignore

# ========================= CAN TAKE NEITHER GRAD NOR JACOBIAN =========================
try:
    _ = jax.grad(lorenz)(test_primal)  # Cannot take grad
except TypeError as e:
    print(e)  # Gradient only defined for scalar-output functions.

try:
    _ = jax.jacfwd(lorenz)(test_primal)  # Cannot take jacobian
except NotImplementedError as e:
    print(e)  # Batching rule for 'lorenz_jvp' not implemented

try:
    _ = jax.jacrev(lorenz)(test_primal)  # Cannot take jacobian
except NotImplementedError as e:
    print(e)  # Batching rule for 'lorenz_vjp' not implemented
```
Replies: 1 comment 5 replies
Hi - as a brief answer to your question:
You'll need to define the batching rule for this jvp rule in order to use `jax.jacfwd` (and likewise a batching rule for the vjp primitive in order to use `jax.jacrev`).
Stepping back, though, you'll need to start over on your implementation. You're confusing two different concepts: primitive jvp rules and custom jvp rules. The former (`ad.primitive_jvps`) are for newly defined primitives; the latter (`jax.custom_jvp`) are for ordinary Python functions built out of existing primitives.

Here it's not clear what you're gaining from defining a custom primitive. A good way forward might be to define your functions as normal Python functions, wrapped in `jax.custom_jvp`.
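A sketch of that direction, reusing the Lorenz right-hand side and the analytic jvp from the question (this illustrates the `jax.custom_jvp` pattern; it is not code from the original reply):

```python
import jax
import jax.numpy as jnp

SIGMA, RHO, BETA = 28.0, 10.0, 8 / 3

@jax.custom_jvp
def lorenz(x):
    return jnp.array([
        SIGMA * (x[1] - x[0]),
        x[0] * (RHO - x[2]) - x[1],
        x[0] * x[1] - BETA * x[2],
    ])

@lorenz.defjvp
def lorenz_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = lorenz(x)
    y_dot = jnp.array([
        x_dot[1] * SIGMA - x_dot[0] * SIGMA,
        -x[2] * x_dot[0] - x[0] * x_dot[2] - x_dot[1] + x_dot[0] * RHO,
        x[1] * x_dot[0] + x[0] * x_dot[1] - x_dot[2] * BETA,
    ])
    return y, y_dot

x = jnp.array([1.0, 0.0, 0.0])
# No hand-written batching rule needed: the jvp rule is ordinary
# traceable JAX code, so jacfwd (vmap of jvp) just works.
J = jax.jacfwd(lorenz)(x)
```

`jax.grad` and `jax.jacrev` work as well, since the jvp rule is linear in the tangents and JAX can transpose it automatically.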
The batching rule for a primitive takes a tuple of arguments and a tuple of batch dims, and evaluates the batched version of the primitive. Since your primitive is implemented via normal JAX operations, you can implement the batching rule via a call to `jax.vmap`. For example: