jax.remat with static_argnums #9153
Dear jax-Team, I want to use `jax.remat` like this:

```python
import jax
import jax.numpy as jnp

@jax.remat
def fun(x, f):
    y = f(x)
    return jnp.sum(jnp.square(y))

fun_grad = jax.jit(jax.grad(fun), static_argnums=1)  # same without JIT
x = jnp.arange(10.)
fun_grad(x, jnp.sin)
```

but calling `fun_grad(x, jnp.sin)` gives an error.
Answered by niklasschmitz, Sep 24, 2022
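A late reply, but just in case here's the fix: `jax.remat` got subsumed by `jax.checkpoint`, which also supports `static_argnums`:

```python
import jax
import jax.numpy as jnp
from functools import partial

# The new jax.checkpoint API supports static_argnums; f is a Python callable,
# not an array, so it has to be static for both checkpoint and jit.
@partial(jax.checkpoint, static_argnums=1)
def fun(x, f):
    y = f(x)
    return jnp.sum(jnp.square(y))

fun_grad = jax.jit(jax.grad(fun), static_argnums=1)
x = jnp.arange(10.)
fun_grad(x, jnp.sin)
```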
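As a quick sanity check (not part of the original answer), the sketch below compares the fixed version against the analytic gradient: d/dx Σ sin(x)² = 2 sin(x) cos(x) = sin(2x). It also shows an alternative that is an assumption on my part rather than the answerer's approach: closing over `f` so that `jax.checkpoint` only ever sees array arguments and no `static_argnums` is needed. `make_fun` and `fun_closed` are hypothetical helper names.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Fixed version from the answer: f is static for checkpoint and for jit.
@partial(jax.checkpoint, static_argnums=1)
def fun(x, f):
    y = f(x)
    return jnp.sum(jnp.square(y))


fun_grad = jax.jit(jax.grad(fun), static_argnums=1)


# Hypothetical alternative: bind f via a closure, so the checkpointed
# function takes only the array argument x.
def make_fun(f):
    @jax.checkpoint
    def fun_closed(x):
        return jnp.sum(jnp.square(f(x)))
    return fun_closed


x = jnp.arange(10.)
g_static = fun_grad(x, jnp.sin)
g_closure = jax.jit(jax.grad(make_fun(jnp.sin)))(x)

# Analytic gradient of sum(sin(x)**2) is 2*sin(x)*cos(x) == sin(2x).
expected = jnp.sin(2 * x)
print(jnp.allclose(g_static, expected, atol=1e-5))
print(jnp.allclose(g_closure, expected, atol=1e-5))
```

The closure version retraces whenever a new `f` is bound (a new `fun_closed` object per `make_fun` call), much like `static_argnums` retraces for each distinct static value; which one reads better is mostly a matter of taste.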
Answer selected by sharadmv