I implemented a neural network that has some Pallas kernels in it. What will happen to the kernels if I jit the entire model training step on a multi-GPU system? Will they also be automatically partitioned?
Answered by sharadmv, Feb 11, 2024
Replies: 1 comment
Answer selected by luyug
Pallas kernels won't be automatically partitioned, though we are brainstorming APIs that would allow adding hooks into the partitioner. The preferred solution is to use jax.experimental.shard_map, so that each device runs the kernel on its own local shard of the data.
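A minimal sketch of that approach, assuming a recent JAX with Pallas installed. The toy kernel `add_one_kernel`, the wrapper `add_one`, and the mesh axis name "data" are hypothetical illustrations. `interpret=True` is set so the sketch also runs on CPU. The replication check is switched off on the assumption that `pallas_call` may not provide a rule for it in every JAX version, and the import falls back to `jax.experimental.shard_map` on older releases that lack `jax.shard_map`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.sharding import Mesh, PartitionSpec as P

# Newer JAX exposes jax.shard_map; older releases only have the experimental one.
# The replication check is disabled (assumption: pallas_call may lack a rule for it).
try:
    from jax import shard_map
    NO_CHECK = {"check_vma": False}
except ImportError:
    from jax.experimental.shard_map import shard_map
    NO_CHECK = {"check_rep": False}


def add_one_kernel(x_ref, o_ref):
    # Hypothetical toy kernel: reads the whole block and writes x + 1.
    o_ref[...] = x_ref[...] + 1.0


def add_one(x):
    # Inside shard_map, x is the per-device shard, so out_shape is local too.
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,  # runs on CPU; drop this to compile the kernel on GPU
    )(x)


# 1-D mesh over all visible devices; the axis name "data" is arbitrary.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Split the leading (batch) axis across devices, run the kernel on each
# local shard, and reassemble the output with the same sharding.
sharded_add_one = jax.jit(
    shard_map(
        add_one,
        mesh=mesh,
        in_specs=(P("data"),),
        out_specs=P("data"),
        **NO_CHECK,
    )
)

# Usage: the batch size must be divisible by the number of devices.
n_dev = len(jax.devices())
x = jnp.arange(8 * n_dev * 128, dtype=jnp.float32).reshape(8 * n_dev, 128)
y = sharded_add_one(x)
```

In a full training step, the shard_map-wrapped call can be used inside the jitted step function, while the surrounding (non-Pallas) parts of the model are still left to jit's automatic partitioning.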