Autodifferentiation through parallelized operators with xmap #14982
cmunna0052 asked this question in Q&A (unanswered)
This question concerns the same basic setup as my previous one, #14879, but with a slightly different approach to sharding that gets a bit further before breaking. I am still trying to shard a feed-forward network on the MNIST dataset by splitting the weight matrices into 10 groups of columns. However, I have now defined an xmapped matrix multiplication operation on its own, with code along the following lines:
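A minimal sketch of such a column-sharded xmapped matmul, assuming a named logical axis 'shard' laid out over a one-dimensional mesh axis 'devices'; the function names, shapes, and mesh setup here are placeholders rather than the exact code:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap
from jax.sharding import Mesh

def matmul_shard(x, w_cols):
    # x: full input, replicated on every shard, shape (in_dim,)
    # w_cols: this shard's block of columns, shape (in_dim, cols_per_shard)
    return x @ w_cols

sharded_matmul = xmap(
    matmul_shard,
    in_axes=[{}, {0: 'shard'}],           # x unmapped, W split over the named axis
    out_axes={0: 'shard'},                # keep a leading 'shard' axis in the output
    axis_resources={'shard': 'devices'},  # lay the named axis out over the mesh
)

# 10 column groups of a (784, 640) weight matrix -> blocks of shape (10, 784, 64).
# The 'shard' axis size (10) must be divisible by the number of devices on the mesh axis.
W_blocks = jnp.zeros((10, 784, 64))
x = jnp.ones((784,))
with Mesh(np.array(jax.devices()), ('devices',)):
    y_blocks = sharded_matmul(x, W_blocks)  # shape (10, 64)
```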
This code works (though I am not sure why I don't have to recombine the x vector at each step with jax.lax.all_gather, and in fact doing so causes an error).
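For illustration of the named-axis semantics only (not the actual network): collectives inside xmap act on the named axis, so a column-then-row split can contract its partial products with jax.lax.psum rather than gathering the activations. All names and shapes below are assumed:

```python
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def two_layer_shard(x, w1_cols, w2_rows):
    # Layer 1 split by columns: each shard computes its slice of the hidden layer.
    h = jax.nn.relu(x @ w1_cols)                # (hidden_per_shard,)
    # Layer 2 split by rows: each shard's partial product is summed over 'shard'.
    return jax.lax.psum(h @ w2_rows, 'shard')   # (out_dim,), named axis reduced away

two_layer = xmap(
    two_layer_shard,
    in_axes=[{}, {0: 'shard'}, {0: 'shard'}],   # x replicated, both weights sharded
    out_axes={},                                # psum removed the 'shard' axis
)

logits = two_layer(jnp.ones((784,)),
                   jnp.zeros((10, 784, 64)),    # W1 as 10 column blocks
                   jnp.zeros((10, 64, 10)))     # W2 as 10 row blocks
```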
Now the problem comes at the next step, where I try to backpropagate with something along the following lines:
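Again only a sketch of the kind of step involved, assuming a cross-entropy loss over the sharded forward pass and a plain SGD update; pmean_shard stands in for the custom_vjp wrapper around jax.lax.pmean mentioned below, and forward, loss_fn, and train_batch are placeholder names:

```python
import jax
import jax.numpy as jnp

# custom_vjp wrapper around a mean over the named axis; it has to be called
# inside the xmapped function, where the 'shard' axis is in scope.
@jax.custom_vjp
def pmean_shard(x):
    return jax.lax.pmean(x, axis_name='shard')

def _pmean_fwd(x):
    return pmean_shard(x), None

def _pmean_bwd(_, g):
    # A mean over the shards sends each shard the mean of the incoming cotangent.
    return (jax.lax.pmean(g, axis_name='shard'),)

pmean_shard.defvjp(_pmean_fwd, _pmean_bwd)

def loss_fn(params, x, y):
    # forward: the xmapped model from the sketch above (placeholder);
    # x: batch of flattened images, y: one-hot labels.
    logits = forward(params, x)
    return -jnp.mean(jnp.sum(y * jax.nn.log_softmax(logits), axis=-1))

@jax.jit
def train_batch(params, x, y, lr=1e-3):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss
```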
This correctly calculates the loss, but it fails in the train_batch portion with the error `assert len(arg) == n, f'length mismatch: [6,6,2]'`. I added the custom_vjp because the regular jax.lax.pmean was throwing a similar error, and I assumed it wasn't a differentiable operator anyway. Any ideas on how I should get through this?