Hi, sorry if this is a naive question, but I can't seem to find much about it anywhere (though there is some related discussion on blackjax). I have a PyTorch function that I wrapped to convert its output to JAX arrays, and I have defined a custom VJP for it that just wraps the PyTorch autodiff gradient. Calling this function or its gradient is fine, but when trying to sample it with blackjax I get an error about converting a traced array to a NumPy array.
The reason I am asking here is that I have more of a JAX question that might help me understand what is wrong with this: what is the meaning of the "level" in the trace shown in the error message?
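For concreteness, here is a minimal sketch of the kind of wrapper I mean (torch.sin stands in for my real function, and the conversion details are only illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np
import torch

@jax.custom_vjp
def torch_sin(x):
  # convert JAX -> NumPy -> torch, run the op, convert back to JAX
  return jnp.asarray(torch.sin(torch.as_tensor(np.asarray(x))).numpy())

def torch_sin_fwd(x):
  return torch_sin(x), x  # save the input as the residual

def torch_sin_bwd(x, g):
  # wrap torch's autodiff for the backward pass
  t = torch.as_tensor(np.asarray(x)).requires_grad_()
  torch.sin(t).backward(torch.as_tensor(np.asarray(g)))
  return (jnp.asarray(t.grad.numpy()),)

torch_sin.defvjp(torch_sin_fwd, torch_sin_bwd)

x = jnp.arange(3.0)
print(torch_sin(x))                               # fine
print(jax.grad(lambda v: torch_sin(v).sum())(x))  # fine
# jax.jit(torch_sin)(x) raises the error described above
```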
The level is not particularly relevant here – the issue is that your function is being called within JIT-compiled code, and so you're attempting to convert a traced variable to a NumPy array, which is not possible. You can get the same error by doing something like this:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def call_numpy_sin(x):
  x_np = np.asarray(x)  # fails: inside jit, x is a tracer, not a concrete array
  out_np = np.sin(x_np)
  return jnp.asarray(out_np)

call_numpy_sin(jnp.arange(10))
```

To understand what's going on here, you might read through How To Think In JAX, in particular the JIT Mechanics section. The short version of the issue is: you can't convert a traced variable to a NumPy array. I don't know of any good example of interoperability between JAX and PyTorch within JIT-compiled code, but you may be able to hack it together using the … Hope that helps!
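As a rough sketch of that kind of hack: jax.pure_callback is one mechanism for calling a host-side function from within jitted code, provided you declare the output shape and dtype up front (NumPy's sin below is a stand-in for a PyTorch call):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
  # executed on the host, outside the trace, so x is a concrete array here
  return np.sin(np.asarray(x))

@jax.jit
def call_numpy_sin_via_callback(x):
  # declare the callback's output shape/dtype so JAX can trace around it
  result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(host_sin, result_shape, x)

print(call_numpy_sin_via_callback(jnp.arange(10.0)))
```

Note that pure_callback has no autodiff rule of its own, so for gradients you would still need to pair it with your custom VJP.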