How to use a custom_vjp function as an input of a jitted function when using call_tf? #10118
I want to use a TensorFlow function in JAX in jit mode, so I chose jax2tf.call_tf to convert my TensorFlow function and pass it as an input to the jitted function.
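A minimal sketch of why passing the function directly fails. It assumes a hypothetical pure-JAX stand-in `my_cos` (a `jax.custom_vjp` function) in place of the `call_tf`-wrapped TensorFlow function, so it runs without TensorFlow. `jax.jit` tries to treat every non-static argument as an array, so handing it a Python callable raises a `TypeError`.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the call_tf-wrapped TF function (assumption:
# a plain JAX function with a custom VJP, so no TensorFlow is needed).
@jax.custom_vjp
def my_cos(x):
    return jnp.cos(x)


def _my_cos_fwd(x):
    return jnp.cos(x), x  # primal output plus residual for the backward pass


def _my_cos_bwd(x, g):
    return (-jnp.sin(x) * g,)  # cotangent for x: d/dx cos(x) = -sin(x)


my_cos.defvjp(_my_cos_fwd, _my_cos_bwd)


@jax.jit
def apply(f, x):
    return f(x) + x


try:
    apply(my_cos, 1.0)  # f would have to be traced like an array argument
    error = None
except TypeError as e:
    error = e  # a callable is not a valid JAX type

print(type(error).__name__)
```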
You can try to use jax.tree_util.Partial to wrap your function.
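A sketch of the `jax.tree_util.Partial` approach, again using a hypothetical pure-JAX `my_cos` with a custom VJP in place of the `call_tf`-wrapped function. `Partial` is a pytree: the wrapped callable is kept as static metadata in the tree structure and any bound arguments (none here) are leaves, so `jit` accepts it as an ordinary argument while gradients with respect to `x` still go through the custom VJP. That a `call_tf` function behaves the same way inside `Partial` is an assumption worth checking on your setup.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


@jax.custom_vjp
def my_cos(x):  # hypothetical stand-in for the call_tf-wrapped function
    return jnp.cos(x)


def _my_cos_fwd(x):
    return jnp.cos(x), x  # save x as the residual


def _my_cos_bwd(x, g):
    return (-jnp.sin(x) * g,)


my_cos.defvjp(_my_cos_fwd, _my_cos_bwd)


@jax.jit
def apply(f, x):
    # f is a Partial pytree: the callable lives in the treedef, not in a leaf
    return f(x) + x


wrapped = Partial(my_cos)  # wrap the function without binding any arguments
y = apply(wrapped, 1.0)
dy_dx = jax.grad(apply, argnums=1)(wrapped, 1.0)  # goes through the custom VJP
print(y, dy_dx)
```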
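Alternatively, you can use static_argnums or static_argnames in jit, i.e. use jax_func = jax.jit(func, static_argnums=0) when func takes the function as its first argument. Or you can use a closure:

```python
from jax.experimental import jax2tf
import jax
import tensorflow as tf


def cos_tf(x):
    return tf.math.cos(x)


# Wrap the TensorFlow function so it can be called from JAX code.
jax_cos = jax2tf.call_tf(cos_tf)


def make_jax_func(recurrent_net):
    def func(x):  # I guess you actually need some inputs here, not just recurrent_net
        return recurrent_net(x) + x  # do something with recurrent_net
    # recurrent_net is captured by the closure, so jit only traces x.
    return jax.jit(func)


jax_func = make_jax_func(jax_cos)
jax_func(1.0)
```

BTW,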
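The `static_argnums` alternative can be sketched the same way. This assumes the jitted function takes the function as its first argument (the hypothetical `apply` below) and again uses a pure-JAX `my_cos` stand-in with a custom VJP. A static argument is hashed rather than traced, so it must be hashable, and `jit` compiles once per distinct function passed in.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def my_cos(x):  # hypothetical stand-in for the call_tf-wrapped function
    return jnp.cos(x)


def _my_cos_fwd(x):
    return jnp.cos(x), x  # save x as the residual


def _my_cos_bwd(x, g):
    return (-jnp.sin(x) * g,)


my_cos.defvjp(_my_cos_fwd, _my_cos_bwd)


def apply(f, x):
    return f(x) + x


# Argument 0 is a compile-time constant: it is hashed, not traced.
apply_jit = jax.jit(apply, static_argnums=0)
# Keyword-based equivalent: jax.jit(apply, static_argnames="f")

y = apply_jit(my_cos, 1.0)
dy_dx = jax.grad(apply_jit, argnums=1)(my_cos, 1.0)  # differentiate w.r.t. x only
print(y, dy_dx)
```

Compared with `Partial`, this keeps the call site unchanged but triggers a recompilation for every new function object, which matters if the function is re-created (for example, by calling `jax2tf.call_tf` again) on each step.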