How to use pjit to speed up divide and conquer algorithm? #9361
AdrienCorenflos asked this question in Q&A (unanswered)
Hi,
I have a divide and conquer algorithm that essentially generalises prefix sums (`jax.lax.associative_scan`) to non-associative operators whilst retaining the logarithmic span complexity (and logarithmic compile time) of prefix sums in the size of the input data. The algorithm, written in an in-place fashion, goes roughly like the sketch below.
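In caricature, the structure is something like the following (a simplified stand-in rather than my actual code: `op` is just a placeholder merge operator, everything acts on the leading time axis, and `T` is assumed to be a power of two):

```python
# Simplified sketch only: `op`, the shapes and the power-of-two assumption
# are placeholders, not the real operator or model.
import jax
import jax.numpy as jnp


def dc_map(xs, op):
    """Log-depth divide-and-conquer combine over the leading axis.

    xs: array of shape (T, ...) with T a power of two.
    op: binary operator taking two blocks of shape (k, ...) and returning
        one merged block of shape (2 * k, ...); it need not be associative.
    """
    T = xs.shape[0]
    num_levels = T.bit_length() - 1  # log2(T) sequential merge levels
    block = 1
    for _ in range(num_levels):
        # Group the current blocks into (left, right) pairs and merge every
        # pair in parallel; the number of blocks halves at each level, so
        # only log2(T) sequential steps are needed.
        pairs = xs.reshape(-1, 2 * block, *xs.shape[1:])
        merged = jax.vmap(lambda p: op(p[:block], p[block:]))(pairs)
        xs = merged.reshape(T, *xs.shape[1:])
        block *= 2
    return xs


# Sanity check: with the classic (associative) prefix-sum merge, dc_map
# reduces to a cumulative sum.
def prefix_sum_merge(left, right):
    return jnp.concatenate([left, left[-1] + right])


print(dc_map(jnp.arange(1.0, 9.0), prefix_sum_merge))  # [1. 3. 6. ... 36.]
```

The point of the real algorithm is precisely that the merge operator does not have to be associative, which is why `jax.lax.associative_scan` cannot be used directly.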
While it works really well on a single GPU, the logarithmic scaling stops happening when `T` becomes too big. Intuitively, I tend to think that it can be made compatible with `pjit` so that the scaling continues past the number of threads on a given GPU. However, I am at a complete loss when trying to use it: I have been trying different combinations of meshes and partition specs, so far to no avail (see figure). I would very much appreciate it (and will acknowledge the help in a coming paper that uses this algorithm) if someone could help me fix this issue. My code is available on Colab (although I think TPU pods are not compatible with pjit on Colab yet?) at the following address: https://colab.research.google.com/drive/1frM5UgGlmky2nbpJCvPSkQIzgSt9hCis?usp=sharing. The tentative pjit version is in the function `dc_map_pjit`; the gist of what I have been trying is sketched below.
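Roughly, one of the combinations I have tried looks like this (one-dimensional device mesh, leading time axis sharded across it; the axis name `"data"` is arbitrary, and `dc_map` / `prefix_sum_merge` refer to the sketch above):

```python
# One of several mesh / PartitionSpec combinations attempted, shown as a
# rough sketch; dc_map and prefix_sum_merge come from the sketch above.
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit, PartitionSpec

# One-dimensional mesh over all available devices.
devices = np.asarray(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading (time) axis of both the input and the output
# across the "data" axis of the mesh.
spec = PartitionSpec("data")

dc_map_pjit = pjit(
    lambda xs: dc_map(xs, prefix_sum_merge),
    in_axis_resources=spec,
    out_axis_resources=spec,
)

with mesh:
    ys = dc_map_pjit(jnp.arange(1.0, 9.0))
```

This compiles and runs, but the scaling I get does not match what I would expect from distributing the leading axis.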
Also, on a side note, if the JAX team is interested, I could contribute the divide and conquer algorithm to the code base.
Thanks a lot to whoever is kind enough to read all the way to here,
Adrien