Replies: 1 comment
-
Okay, I was playing around with

Traceback (most recent call last):
  File "paint.py", line 46, in <module>
    print(make_jaxpr(jacfwd(fun))(particles))
  File "paint.py", line 31, in paint
    jnp.zeros(mesh_size),
  File "~/.local/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 3665, in zeros
    return lax.full(shape, 0, dtype)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=1/0)>.

I added the mesh array as an argument. By checking the
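For readers who hit the same `TypeError`: one common cause is that the mesh shape reaches `jnp.zeros` as a traced JAX array rather than as plain Python integers. The sketch below is only an illustration under that assumption. The `paint` body, `mesh_shape`, and `traced_shape_error` are hypothetical stand-ins, not the code from `paint.py`. It binds the shape with `functools.partial` so it stays concrete while `particles` is traced.

```python
from functools import partial

import jax
import jax.numpy as jnp


def paint(particles, mesh_shape):
    """Hypothetical stand-in: adds the sum of all positions to one mesh cell."""
    mesh = jnp.zeros(mesh_shape)  # needs a concrete shape
    return mesh.at[0, 0, 0].add(particles.sum())


particles = jnp.ones((4, 3))
mesh_shape = (2, 2, 2)  # plain Python ints, never traced

# Bind the shape outside of tracing; only `particles` becomes a tracer.
fun = partial(paint, mesh_shape=mesh_shape)
jaxpr = jax.make_jaxpr(jax.jacfwd(fun))(particles)
jac_shape = jax.eval_shape(jax.jacfwd(fun), particles).shape


def traced_shape_error():
    """Pass the shape as an array argument, so it is traced, and return the error text."""
    try:
        jax.make_jaxpr(paint)(particles, jnp.array(mesh_shape))
    except TypeError as err:
        return str(err)
    return None
```

With `jax.jit`, marking the shape argument with `static_argnums` should have a similar effect, provided the shape is passed as a hashable tuple.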
-
Hi, I have two questions and I'm a bit confused about autodiff.
Is there an easy way to calculate Jacobian tensors of order 3 or higher?
Take, for example, this function, which takes an array of positions (an Nx3 matrix) as input and outputs a 3D mesh (an MxMxM tensor).
If I write the following, the process gets killed because it eats up all the memory.
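A rough way to see why memory runs out is to look at the shape of the full Jacobian without allocating it. The sketch below assumes a hypothetical Gaussian-deposit `paint` as a stand-in for the real function, and the sizes in `jacobian_bytes` (10,000 particles, a 64-cell mesh, float32) are made up for illustration. `jax.eval_shape` only traces shapes, so it stays cheap.

```python
import jax
import jax.numpy as jnp


def paint(particles, mesh_shape=(4, 4, 4)):
    """Hypothetical stand-in for paint: Gaussian deposit of particles onto a mesh."""
    axes = [jnp.arange(m, dtype=particles.dtype) for m in mesh_shape]
    grid = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1)  # (M, M, M, 3)
    d2 = jnp.sum((grid[..., None, :] - particles) ** 2, axis=-1)   # (M, M, M, N)
    return jnp.exp(-0.5 * d2).sum(axis=-1)                         # (M, M, M)


particles = jnp.ones((5, 3))

# Shape of the full Jacobian: output shape followed by input shape.
jac_shape = jax.eval_shape(jax.jacfwd(paint), particles).shape  # (4, 4, 4, 5, 3)


def jacobian_bytes(n_particles, mesh_size, itemsize=4):
    """Bytes needed for a dense (M, M, M, N, 3) Jacobian."""
    return mesh_size**3 * n_particles * 3 * itemsize


# Assumed sizes: 10,000 particles on a 64^3 mesh in float32, about 29 GiB.
example_gib = jacobian_bytes(10_000, 64) / 2**30
```

The Jacobian has one entry per (mesh cell, particle coordinate) pair, so its size grows as M^3 times 3N even when most entries are zero.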
But if I wrap the `paint` function in another function that returns a scalar, I can calculate the gradient without problems.

So we get to the second question: what's really happening here? Why does this "work" while the previous code block hangs? Isn't JAX applying the chain rule through the `paint` function?
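As a hedged illustration of the difference: reverse-mode `jax.grad` of a scalar wrapper does apply the chain rule through `paint`, but as a single vector-Jacobian product, so the full (M, M, M, N, 3) Jacobian is never materialized. The sketch below uses a hypothetical Gaussian-deposit `paint` and a hypothetical scalar wrapper `total_mass`, since the original code is not shown, and keeps the sizes tiny so that the dense Jacobian can be built for comparison.

```python
import jax
import jax.numpy as jnp


def paint(particles, mesh_shape=(4, 4, 4)):
    """Hypothetical stand-in for paint: Gaussian deposit of particles onto a mesh."""
    axes = [jnp.arange(m, dtype=particles.dtype) for m in mesh_shape]
    grid = jnp.stack(jnp.meshgrid(*axes, indexing="ij"), axis=-1)  # (M, M, M, 3)
    d2 = jnp.sum((grid[..., None, :] - particles) ** 2, axis=-1)   # (M, M, M, N)
    return jnp.exp(-0.5 * d2).sum(axis=-1)                         # (M, M, M)


def total_mass(particles):
    """Hypothetical scalar wrapper around paint."""
    return paint(particles).sum()


particles = jax.random.uniform(jax.random.PRNGKey(0), (5, 3)) * 3.0

# Reverse mode on a scalar: the result has the shape of the input, (N, 3).
g = jax.grad(total_mass)(particles)

# The same quantity as one explicit vector-Jacobian product with a cotangent of ones.
mesh, vjp_fn = jax.vjp(paint, particles)
(g_vjp,) = vjp_fn(jnp.ones_like(mesh))

# Dense Jacobian, feasible only because the sizes here are tiny.
full_jac = jax.jacfwd(paint)(particles)      # (4, 4, 4, 5, 3)
g_from_jac = full_jac.sum(axis=(0, 1, 2))    # contract the mesh axes with ones
```

All three results agree on this small case, but only the last route stores every entry of the Jacobian, which is what becomes infeasible for realistic N and M.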