xmap properly compiles but then raises an axis size error #14879
cmunna0052 asked this question in Q&A (Unanswered)
I am not sure whether this is a bug or simply a misunderstanding on my part. I am trying to shard a feed-forward network with an arbitrary number of hidden layers using xmap. My plan is the following: each weight matrix is split column-wise into 10 groups (one per device). At each layer, the full current x-vector is sent to all 10 devices. On each device it is multiplied by that device's 1/10th of the weight-matrix columns, added to the matching 1/10th of the bias vector, and passed through the activation, producing 1/10th of the next x-vector. Then `jax.lax.all_gather` is called to recombine the sharded pieces into the full x-vector. This repeats until the last layer. Here is the code:
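For reference, the scheme described above can be sketched without xmap. This is a minimal sketch, not the original code: it emulates the 10 devices with `jax.vmap` over a named axis `"shards"` (a swapped-in technique, since `xmap` was experimental and has since been deprecated), and it assumes `tanh` as the activation and layer widths divisible by 10. The helper names `split_layer`, `forward_one_shard`, `sharded_forward` and `dense_forward` are hypothetical.

```python
import jax
import jax.numpy as jnp

N_SHARDS = 10  # stands in for the 10 devices (emulated, not real devices)

def split_layer(w, b, n_shards=N_SHARDS):
    """Split one dense layer column-wise into n_shards equal pieces.

    w: (d_in, d_out) -> (n_shards, d_in, d_out // n_shards)
    b: (d_out,)      -> (n_shards, d_out // n_shards)
    Assumes d_out is divisible by n_shards.
    """
    d_in, d_out = w.shape
    k = d_out // n_shards
    w_s = w.reshape(d_in, n_shards, k).transpose(1, 0, 2)
    b_s = b.reshape(n_shards, k)
    return w_s, b_s

def forward_one_shard(x, shard_params):
    # Runs once per shard. x is the full activation vector; each (w, b)
    # pair holds only this shard's 1/N_SHARDS of the columns and bias.
    for w, b in shard_params:
        y_local = jnp.tanh(x @ w + b)  # this shard's slice of the next x
        # Collect every shard's slice: shape (N_SHARDS, d_out // N_SHARDS) ...
        gathered = jax.lax.all_gather(y_local, axis_name="shards")
        # ... then flatten in shard order to rebuild the full next x-vector.
        x = gathered.reshape(-1)
    return x

def sharded_forward(x, params):
    split = [split_layer(w, b) for w, b in params]
    # x is shared by all shards (in_axes=None); parameters are mapped over
    # their leading shard axis (in_axes=0).
    per_shard = jax.vmap(forward_one_shard, in_axes=(None, 0),
                         axis_name="shards")(x, split)
    return per_shard[0]  # every shard ends up holding the same full output

def dense_forward(x, params):
    # Unsharded reference implementation, for comparison.
    for w, b in params:
        x = jnp.tanh(x @ w + b)
    return x
```

Comparing `sharded_forward(x, params)` with `dense_forward(x, params)` via `jnp.allclose` is a quick way to check that the column split and the gather order agree before moving the same per-shard function onto real devices.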
When I run this code, the compilation of `forward` completes (I can tell by adding print statements throughout), but the actual calculation then fails at:
`assert axis_size == frame_size, "axis size doesn't match"`
Does anyone see what is going on here? Is there a better or easier way to do this? One solution I can see is to split up the parameter matrices manually, but I thought xmap would be able to handle that for me.
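If the parameters are split manually, one way to do it is sketched below. The helpers `split_params` and `merge_params` are hypothetical, and the sketch assumes weights stored as `(d_in, d_out)` with a column-wise split. It makes one constraint explicit: every layer's output width must be divisible by the number of shards so that the shard axis has the same size for every layer, and the helper reports a violation up front.

```python
import jax.numpy as jnp

def split_params(params, n_shards):
    """Column-split each (w, b) layer into n_shards equal pieces.

    Returns a list of (w_s, b_s) with a leading shard axis:
    w_s: (n_shards, d_in, d_out // n_shards), b_s: (n_shards, d_out // n_shards).
    """
    out = []
    for i, (w, b) in enumerate(params):
        d_in, d_out = w.shape
        if d_out % n_shards != 0:
            raise ValueError(
                f"layer {i}: width {d_out} is not divisible by {n_shards} shards")
        k = d_out // n_shards
        out.append((w.reshape(d_in, n_shards, k).transpose(1, 0, 2),
                    b.reshape(n_shards, k)))
    return out

def merge_params(split):
    """Inverse of split_params: concatenate the shards back into full layers."""
    merged = []
    for w_s, b_s in split:
        n, d_in, k = w_s.shape
        merged.append((w_s.transpose(1, 0, 2).reshape(d_in, n * k),
                       b_s.reshape(n * k)))
    return merged
```

Keeping the shard axis first means the split arrays can be fed straight to a mapping transform with `in_axes=0`, while `merge_params` gives a round-trip check that no columns were reordered.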