
You can try wrapping your function with jax.tree_util.Partial.

Passing zero arguments to Partial effectively wraps the original function, making it a valid argument to JAX-transformed functions.
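A minimal sketch of the Partial approach, assuming a toy function `scale` and a jitted `apply_fn` (both hypothetical names, standing in for your recurrent_net and the function you are jitting):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


@jax.jit
def apply_fn(f, x):
    # f arrives as a pytree: the wrapped callable is static metadata,
    # and any arguments bound into the Partial (none here) are traced leaves.
    return f(x) + x


def scale(x):
    # Hypothetical stand-in for recurrent_net.
    return 2.0 * x


# Passing `scale` directly would be rejected by jit, because a plain function
# is not a valid array argument. Partial(scale), with no bound arguments, is.
result = apply_fn(Partial(scale), jnp.array(3.0))
```

Because Partial keeps the callable in the pytree's static structure, jit reuses the compiled version as long as you pass the same function, and retraces if you pass a different one.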

Alternatively, you can mark the function argument as static with static_argnums or static_argnames in jit, e.g. jax_func = jax.jit(func, static_argnums=0).
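A sketch of the static-argument route, assuming `func` takes the callable as its first argument and using a hypothetical `double` function in place of recurrent_net; the static_argnames variant is shown alongside:

```python
import jax
import jax.numpy as jnp


def func(f, x):
    # f is a plain Python callable; it is only allowed here because
    # jit is told not to trace it.
    return f(x) + x


def double(x):
    # Hypothetical stand-in for recurrent_net.
    return 2.0 * x


# Argument 0 (the callable) is static: it must be hashable, and plain
# functions hash by identity, so this works.
jax_func = jax.jit(func, static_argnums=0)
out = jax_func(double, jnp.array(3.0))

# Same idea, selecting the static argument by name instead of position.
jax_func_named = jax.jit(func, static_argnames="f")
out_named = jax_func_named(f=double, x=jnp.array(3.0))
```

Since static arguments are compared by hash and equality, passing a new function object on every call (for example, a fresh lambda) forces a recompilation each time.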

Or you can use a closure:

```python
from jax.experimental import jax2tf
import jax
import tensorflow as tf


def cos_tf(x):
    return tf.math.cos(x)

# Wrap the TF function so it can be called from JAX code
jax_cos = jax2tf.call_tf(cos_tf)

def make_jax_func(recurrent_net):
    def func(x):  # I guess you actually need some inputs here, not just recurrent_net
        return recurrent_net(x) + x  # do something with recurrent_net
    return jax.jit(func)
# …
```
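The closure pattern can be tried without TensorFlow. In this sketch jnp.cos is a hypothetical stand-in for the call_tf-wrapped function:

```python
import jax
import jax.numpy as jnp


def make_jax_func(recurrent_net):
    # recurrent_net is captured by the closure, so jit never receives it as
    # an argument; it is baked into the traced computation instead.
    def func(x):
        return recurrent_net(x) + x

    return jax.jit(func)


# Hypothetical stand-in for jax_cos / your recurrent net.
jax_func = make_jax_func(jnp.cos)
y = jax_func(jnp.array(0.0))
```

The captured function is fixed when make_jax_func is called, so switching to a different network means building (and compiling) a new jitted function.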

Answer selected by hejujie