
It helps to think of jax2tf.convert as having the same limitations as jax.jit. Hence you have two options, similar to the ones you would have if you wanted to jit the function in JAX:

  • in most cases, when the train parameter is used in a Python conditional, you have to convert the function twice, once for each value of train (just as you would have to jit it twice, either explicitly or implicitly by using the static_argnums parameter of jax.jit). Otherwise you will get a ConcretizationTypeError (from both jax.jit and jax2tf.convert). For example,
    import jax

    def f_jax_1(x, train):
      if train:  # Note the Python conditional
        return x
      else:
        return x * 2.

    # jax.jit(f_jax_1)(2., True)  # --> Concretization…
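A minimal, runnable sketch of the failure and of the "compile once per value of train" workaround, using only jax.jit (so it runs without TensorFlow). It assumes a recent JAX in which the error raised at the Python if is jax.errors.ConcretizationTypeError or a subclass of it. The names f_jit_static, f_jit_train and f_jit_eval are illustrative only. With jax2tf.convert the analogous step would be to convert functools.partial(f_jax_1, train=True) and functools.partial(f_jax_1, train=False) separately; that part is not executed here.

```python
import functools

import jax


def f_jax_1(x, train):
    if train:  # Python conditional: needs a concrete (non-traced) `train`
        return x
    else:
        return x * 2.


# Passing `train` as an ordinary jitted argument makes it a tracer,
# so the `if` cannot decide which branch to take at trace time.
try:
    jax.jit(f_jax_1)(2., True)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False

# Option A: mark `train` as static; jax.jit then traces and compiles
# a separate version of the function for each distinct value of `train`.
f_jit_static = jax.jit(f_jax_1, static_argnums=1)

# Option B: bind `train` in Python first and jit each variant explicitly,
# i.e. "jit it twice" (the same shape as converting twice with jax2tf).
f_jit_train = jax.jit(functools.partial(f_jax_1, train=True))
f_jit_eval = jax.jit(functools.partial(f_jax_1, train=False))

print(traced_ok)                        # False
print(f_jit_static(2., True))           # 2.0
print(f_jit_static(2., False))          # 4.0
print(f_jit_train(2.), f_jit_eval(2.))  # 2.0 4.0
```

Both options end up with two compiled functions; static_argnums just manages the per-value cache for you, while the functools.partial form makes the two variants explicit, which is the form you would hand to jax2tf.convert.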


gnecula (Maintainer, Author), Mar 25, 2022

Answer selected by gnecula
Category
Q&A