-
(posted on behalf of a user) I was testing jax2tf (works very well btw. I'm able to save a model trained in JAX and continue training in TF) and I could not figure out how to convert and save the fwd pass of a model with a dropout layer. The fwd pass function has a `train` argument.
Do I have to convert this function twice, once with `train=True` and once with `train=False`? Is that the only way?
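The reply compares `jax2tf.convert` to `jax.jit`. The following is a minimal plain-JAX sketch of that comparison. The forward function `fwd`, its `params` dict, and `DROPOUT_RATE` are hypothetical stand-ins for the user's model. The sketch shows two ways to specialize on a Python-level `train` flag. The first marks it static with `static_argnums`. The second binds it with `functools.partial` so that there is one function per mode.

```python
import functools

import jax
import jax.numpy as jnp

DROPOUT_RATE = 0.5  # hypothetical rate, fixed for the sketch


def fwd(params, x, key, train):
    """Hypothetical forward pass with dropout; `train` is a plain Python bool."""
    h = x @ params["w"]
    if train:  # Python conditional, so `train` must be concrete at trace time
        keep = jax.random.bernoulli(key, p=1.0 - DROPOUT_RATE, shape=h.shape)
        h = jnp.where(keep, h / (1.0 - DROPOUT_RATE), 0.0)
    return h


# Implicit specialization: `train` (position 3) is static, so jit traces
# and compiles a separate version for each distinct value it sees.
fwd_jit = jax.jit(fwd, static_argnums=3)

# Explicit specialization: one jitted function per mode.
fwd_train = jax.jit(functools.partial(fwd, train=True))
fwd_eval = jax.jit(functools.partial(fwd, train=False))

# By contrast, jax.jit(fwd)(params, x, key, True) traces `train`, and the
# `if train:` line raises jax.errors.ConcretizationTypeError.
```

The explicit variant matches the "convert twice" approach from the question. Under the same assumptions, you would pass `functools.partial(fwd, train=True)` and `functools.partial(fwd, train=False)` to `jax2tf.convert` separately. That step is not shown here because it requires TensorFlow.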
-
It helps to think of `jax2tf.convert` as having the same limitations as `jax.jit`. Hence, you have two options, similar to the options you'd have if you want to jit the function in JAX: if the `train` parameter is used in Python conditionals, you have to convert the function twice (just like you'd have to jit it twice, either explicitly, or implicitly by using the `static_argnums` parameter for `jax.jit`). Otherwise, you will get a `ConcretizationTypeError` (both from `jax.jit` and `jax2tf.convert`). For example,