Replies: 1 comment · 4 replies
The answer is sort of. So long as the value that the …
Hi everyone! I've been using PyTorch but I'm new to JAX.
I'd like to ask whether JAX can handle irregular functions that produce tensors of variable shapes. This may seem weird at first sight, but a real use case of mine is a simulation in which the shapes of the tensors have geometric meaning.
For the sake of simplicity, take a look at the (insane) example below.
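A minimal stand-in for that kind of function (the name `irregular_fn` and the way the output length is derived are made up purely for illustration):

```python
import jax.numpy as jnp

def irregular_fn(x):
    # The output length is derived from the *value* of x, so different
    # inputs produce tensors of different shapes.
    n = int(jnp.floor(jnp.sum(jnp.abs(x)))) + 1
    # Each element is a smooth function of x; only the shape is
    # data-dependent, so the function is pure and differentiable
    # (away from the points where the floor jumps).
    return jnp.mean(x) * jnp.arange(1, n + 1)

print(irregular_fn(jnp.array([0.3, 0.7])).shape)  # (2,)
print(irregular_fn(jnp.array([1.5, 2.5])).shape)  # (5,)
```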
To make this example less insane, you can think of another one: training a fully-convolutional auto-encoder on images of different sizes, without padding.
This example is over-simplified, but you can see it's differentiable and a pure function, even though the algorithm generates irregular tensors of different shapes. In theory, it's totally fine to do the compute-graph construction in parallel with PyTorch's Autograd, but I cannot really run the `map` in parallel with it because of the GIL, so I wonder whether JAX can do this and give me correct gradients (a tiny sketch of the behaviour I'm hoping for is below). Thanks a lot!
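For concreteness, here is that sketch, assuming eager (non-jitted) execution; the names `loss` and `n_steps` are made up:

```python
import jax
import jax.numpy as jnp

def loss(x, n_steps):
    # n_steps is a plain Python int, so every intermediate shape is
    # known once the call starts, even though it changes from call to call.
    y = x
    for _ in range(n_steps):
        y = jnp.concatenate([y, y * y])  # shape doubles each iteration
    return jnp.sum(y ** 2)

x = jnp.array([1.0, 2.0])
print(jax.grad(loss)(x, 2))  # gradient w.r.t. x; works in eager mode
# Under jax.jit the size-determining value has to be static, e.g.
# jax.jit(loss, static_argnums=1)(x, 2) recompiles for each distinct n_steps;
# deriving a shape from a traced array value instead raises a tracer error.
```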