A late reply, but just in case, here's the fix:
jax.remat has been subsumed by jax.checkpoint, which also supports static_argnums:

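```python
import jax
import jax.numpy as jnp
from functools import partial

# The new jax.checkpoint API supports static_argnums, so the Python callable `f`
# is treated as a static value rather than a traced array argument.
@partial(jax.checkpoint, static_argnums=1)
def fun(x, f):
    y = f(x)
    return jnp.sum(jnp.square(y))

# `f` must also be marked static for jit, because a function is not a valid JAX type.
fun_grad = jax.jit(jax.grad(fun), static_argnums=1)

x = jnp.arange(10.)
fun_grad(x, jnp.sin)
```

Output:

```
DeviceArray([ 0.        ,  0.90929735, -0.7568025 , -0.2794155 ,
              0.98935825, -0.5440211 , -0.5365729 ,  0.9906074 ,
             -0.2879033 , -0.75098723], dtype=float32)
```

(On JAX 0.4 and later the same values are displayed as Array([...], dtype=float32), since DeviceArray was replaced by jax.Array.)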
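As a quick sanity check (a sketch, assuming a recent JAX release), the gradient of sum(sin(x)^2) is 2 sin(x) cos(x) = sin(2x), which is exactly what the values above show. The snippet below verifies that, and also compares against a closure-based variant that avoids static_argnums altogether. The helper name make_loss is hypothetical and not part of the original answer.

```python
import jax
import jax.numpy as jnp
from functools import partial


# Same as the answer: the callable `f` (argument 1) is static for jax.checkpoint.
@partial(jax.checkpoint, static_argnums=1)
def fun(x, f):
    return jnp.sum(jnp.square(f(x)))


# Hypothetical alternative: close over `f` so jax.checkpoint only sees arrays.
def make_loss(f):
    @jax.checkpoint
    def loss(x):
        return jnp.sum(jnp.square(f(x)))
    return loss


x = jnp.arange(10.0)
grad_static = jax.jit(jax.grad(fun), static_argnums=1)(x, jnp.sin)
grad_closure = jax.jit(jax.grad(make_loss(jnp.sin)))(x)

# Analytic gradient: d/dx sum(sin(x)^2) = 2 sin(x) cos(x) = sin(2x)
expected = jnp.sin(2.0 * x)
print(jnp.allclose(grad_static, expected, atol=1e-6))
print(jnp.allclose(grad_closure, expected, atol=1e-6))
```

The closure form is handy when the callable is fixed for a given model; static_argnums is more convenient when the same checkpointed function is reused with different callables.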

Answer selected by sharadmv