custom_jvp calculations with and without jax.jit #18322
-
Hello, I have a question about how JAX will compute a `custom_jvp` with and without `jax.jit`. Consider the following function:

```python
def expensive_computation(x: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
    # ... does some operations that are slow ...
    return output

@jax.custom_jvp
def forward(x: jax.Array, costs: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
    v = expensive_computation(x, key)
    return jnp.dot(costs, v)
```

Let's assume for simplicity that we are only ever going to take the derivative of `forward` with respect to `x`:

```python
@forward.defjvp
def forward_jvp(primals, tangents):
    x, costs, key = primals
    tangent_x, _, _ = tangents  # only the tangent for x is needed
    # recompute the primal output
    v = expensive_computation(x, key)
    primal_out = jnp.dot(costs, v)
    # compute the tangent output corresponding to x; no tangents are needed
    # for costs and key, and the returned tangent must match the structure
    # of the (scalar) primal output
    tangent_out = jnp.dot(v, tangent_x)
    return primal_out, tangent_out
```

My questions are the following: when I compute the gradient of `forward`, will `expensive_computation` be evaluated once or twice, and does wrapping the call in `jax.jit` change this?
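For concreteness, here is a runnable sketch of this setup; the cheap stand-in for `expensive_computation` and the example call are illustrative additions, not part of the original code:

```python
import jax
import jax.numpy as jnp

def expensive_computation(x, key):
    del key  # unused in this cheap stand-in
    return jnp.sin(x)  # placeholder for the real slow computation

@jax.custom_jvp
def forward(x, costs, key):
    return jnp.dot(costs, expensive_computation(x, key))

@forward.defjvp
def forward_jvp(primals, tangents):
    x, costs, key = primals
    tangent_x, _, _ = tangents
    v = expensive_computation(x, key)  # recomputed inside the rule
    return jnp.dot(costs, v), jnp.dot(v, tangent_x)

key = jax.random.PRNGKey(0)
x, costs = jnp.ones(4), jnp.ones(4)
value, grad_x = jax.value_and_grad(forward)(x, costs, key)
```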
Another way to implement this is via `defjvps`, defining one rule per argument:

```python
def g(tangent_x, primal_out, x, costs, key):
    """Custom JVP for forward with respect to the first variable x."""
    return jnp.dot(expensive_computation(x, key), tangent_x)

forward.defjvps(g, None, None)
```

This differs from the previous version in that `g` is handed the primal output `primal_out` directly, rather than recomputing it inside the rule.
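For reference, `defjvps` takes one rule per positional argument, and each rule receives that argument's tangent, the primal output `ans`, and the primal inputs. A minimal self-contained sketch with a toy function (not the expensive example above):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

f.defjvps(
    lambda x_dot, ans, x, y: jnp.cos(x) * x_dot * y,  # rule for x
    lambda y_dot, ans, x, y: jnp.sin(x) * y_dot,      # rule for y
)

print(jax.grad(f)(0.5, 2.0))  # 2.0 * cos(0.5)
```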
Many thanks for your help. Please do let me know if any part of the question is unclear.
-
If you want to figure out exactly which computations are being performed in a given JAX operation, I'd suggest using `jax.make_jaxpr` to inspect the sequence of primitives. But there's no reason for me to do that: it's easy enough for you to do yourself! Hope that helps.
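For example, a sketch of this kind of inspection (the specific calls here are illustrative):

```python
import jax
import jax.numpy as jnp

# The jaxpr shows the primitives JAX traces for the gradient.
print(jax.make_jaxpr(jax.grad(jnp.sin))(1.0))

# The compiled HLO shows what actually runs after XLA optimization.
print(jax.jit(jax.grad(jnp.sin)).lower(1.0).compile().as_text())
```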
-
One reason you're seeing redundant computation is the use of a custom JVP rule that recomputes the primal output inside the rule. The considerations of JVP versus VJP are separate. It is often a good idea to customize JVPs when possible, so that your code supports forward-mode AD, and I find that it is often easier to write JVPs as well. Still, you can avoid these common subexpressions in either setting.
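As an illustration of sharing the common subexpression in both settings (a sketch of my own, with `jnp.exp` standing in for the expensive computation):

```python
import jax
import jax.numpy as jnp

# JVP: compute the shared value once inside the rule.
@jax.custom_jvp
def f(x):
    return jnp.exp(x) * x

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    e = jnp.exp(x)  # shared subexpression, computed once
    return e * x, e * (1 + x) * x_dot

# VJP: save the shared value as a residual in the forward pass.
@jax.custom_vjp
def g(x):
    return jnp.exp(x) * x

def g_fwd(x):
    e = jnp.exp(x)
    return e * x, (x, e)  # residuals reused in the backward pass

def g_bwd(res, ct):
    x, e = res
    return (ct * e * (1 + x),)

g.defvjp(g_fwd, g_bwd)

print(jax.grad(f)(1.0), jax.grad(g)(1.0))  # both: 2 * e ≈ 5.4366
```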
-
Just coming back to this discussion to clarify my issue in the simplest possible example. I have tried to make this comment self-contained, so there is no need to refer to the above conversation.

Imagine we have a function `f` whose evaluation involves an expensive computation. Imagine that part of the computation of the gradient of `f` is that same expensive computation.

Goal: We would like to implement the custom gradient in JAX so that the variable corresponding to the expensive calculation is saved for the backward pass, so we avoid repeating the expensive calculation when calling `jax.value_and_grad`.

Let's consider the following toy example, where, as @jakevdp suggested, we represent the expensive calculation by `jnp.sin`. For clarity, I have provided an implementation using both `jax.custom_jvp` and `jax.custom_vjp`:

```python
import jax
import numpy as np
import jax.numpy as jnp
from jaxlib import xla_client  # used to dump the HLO computation graphs below

@jax.custom_jvp
def fjvp(x):
    s = jnp.sin(x)
    return jnp.dot(s, x)

@fjvp.defjvp
def fjvp_jvp(primals, tangents):
    x = primals[0]
    tangent_x = tangents[0]  # only the tangent for x is needed
    # recompute the primal output
    s = jnp.sin(x)
    primal_out = jnp.dot(s, x)
    # compute the tangent output corresponding to x
    c = jnp.cos(x)
    J = s + x * c
    tangent_out = jnp.dot(J, tangent_x)
    return primal_out, tangent_out

@jax.custom_vjp
def fvjp(x):
    s = jnp.sin(x)
    return jnp.dot(s, x)

def fvjp_fwd(x):
    s = jnp.sin(x)
    sdot = jnp.dot(s, x)
    return sdot, (x, s)  # save x and the expensive s for the backward pass

def fvjp_bwd(res, g):
    x, s = res
    c = jnp.cos(x)
    J = s + x * c
    return (g * J.T,)

fvjp.defvjp(fvjp_fwd, fvjp_bwd)
```

If we look at the jaxpr for `hjvp = jax.value_and_grad(fjvp)`:
```
{ lambda ; a:f32[3]. let
    b:f32[3] = sin a
    c:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] b a
    d:f32[3] = cos a
    e:f32[3] = mul a d
    f:f32[3] = add b e
    g:f32[3] = dot_general[dimension_numbers=(([], []), ([], []))] 1.0 f
  in (c, g) }
```
"jaxpr for hvjp = jax.value_and_grad(fvjp)"
```
{ lambda ; a:f32[3]. let
    b:f32[3] = sin a
    c:f32[] = dot_general[dimension_numbers=(([0], [0]), ([], []))] b a
    d:f32[3] = cos a
    e:f32[3] = mul a d
    f:f32[3] = add b e
    g:f32[3] = mul 1.0 f
  in (c, g) }
```

Just as we wanted, in both jaxprs the expensive `sin` appears only once, and its value (`b`) is reused for both the primal output and the gradient.

However, the problem occurs when calling these functions under `jax.jit`. Using `xla_client` to dump the compiled computations, below are PNG files depicting the HLO computation graphs:

[figure: HLO graph of the jitted `jax.value_and_grad(fjvp)`]

[figure: HLO graph of the jitted `jax.value_and_grad(fvjp)`]

As we can see, when performing optimizations the XLA compiler produces repeated calls to `sin`. Is there some way of structuring my JAX code to avoid this recomputation occurring when calling `jax.jit`?

Many thanks for your help.
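The jaxprs above can be reproduced with something like the following (a sketch, assuming the definitions earlier in this comment):

```python
x = jnp.ones(3, dtype=jnp.float32)

print(jax.make_jaxpr(jax.value_and_grad(fjvp))(x))
print(jax.make_jaxpr(jax.value_and_grad(fvjp))(x))
```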
-
I suspect in this case, the compiler noticed that `sin` is actually not that expensive a computation, and decided that it would be optimal to not save the intermediate value for use in both fusions.
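One way to check this is to count `sine` ops in the optimized HLO text (a sketch, assuming the `fjvp` definition from the comment above; `sine` is the HLO opcode emitted for `jnp.sin`):

```python
x = jnp.ones(3, dtype=jnp.float32)
hlo = jax.jit(jax.value_and_grad(fjvp)).lower(x).compile().as_text()
print(hlo.count("sine("))  # 2 would indicate the compiler recomputes sin
```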