Programming TPUs in JAX | How To Scale Your Model #12
Replies: 2 comments 5 replies
-
Hi! Thanks for the great book. I believe the solution attached to Problem 1 doesn't correspond to the problem description (it's only on one axis). I have been running into an issue with my own solution and have written it up in a GitHub Issue. I would really appreciate it if you could help me out. I have also hit another error that I couldn't solve. I run this line (as described in this chapter): replicated = jax.lax.with_sharding_constraint(replicated, P(None, 'Y')) and I keep getting an error:
Has the API perhaps changed?
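For anyone who wants to reproduce the with_sharding_constraint call from the comment above in isolation, here is a minimal sketch. The error text is not shown in the thread, so this is not a diagnosis. It only assumes that a bare P(None, 'Y') needs a mesh to be resolvable, and sidesteps that by wrapping the spec in a NamedSharding. The mesh shape, the axis names "X" and "Y", the array shape, and the helper name constrain are all illustrative assumptions, not taken from the book.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 2D mesh over whatever devices are available.
# On a single CPU this is a 1x1 mesh, which is still valid.
devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(1, -1), axis_names=("X", "Y"))

@jax.jit
def constrain(x):
    # Wrapping the PartitionSpec in a NamedSharding ties it to an explicit
    # mesh, so the call does not depend on an ambient mesh context.
    sharding = NamedSharding(mesh, P(None, "Y"))
    return jax.lax.with_sharding_constraint(x, sharding)

# The second dimension is a multiple of the 'Y' axis size so it divides evenly.
x = jnp.ones((8, 8 * devices.size))
out = constrain(x)
print(out.shape, out.sharding)
```

If this runs but the original line does not, the difference is probably in how the mesh is supplied rather than in with_sharding_constraint itself, though that depends on the JAX version and the surrounding code.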
-
I'm pretty sure this W[D_X, F] should be W[D_Y, F].
-
JAX programming, broadly construed!