Lower latency associative scan option #10599
Unanswered
oliverdutton asked this question in Ideas
Replies: 0 comments
In one of my problems the implementation was bottlenecked by a cumulative matmul. JAX has a handy implementation of a work-efficient associative scan for this, `lax.associative_scan`, which reduces the procedure from $N$ steps to $2\log_2{N} - 2$ steps. There is also a work-inefficient implementation that reduces this to $\log_2{N}$ steps and is faster for small problem sizes where the GPU is not saturated.

Would anyone else be interested in having a work-inefficient option in `lax.associative_scan`? If so, I can put together a pull request.
See https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda, http://www.cs.cmu.edu/~guyb/papers/Ble93.pdf and https://en.wikipedia.org/wiki/Prefix_sum
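
For anyone who wants to experiment, here is a minimal sketch of what a work-inefficient (Hillis-Steele style) scan could look like in JAX. It is an illustration under stated assumptions, not the code from the original post and not an existing JAX API: the name `hillis_steele_scan` is hypothetical, it only handles a single array scanned along axis 0 (no pytrees, `axis` or `reverse` arguments), and `fn` is assumed to be associative and vectorized over the leading axis, as `lax.associative_scan` also requires.

```python
import jax
import jax.numpy as jnp
from jax import lax


def hillis_steele_scan(fn, elems):
    """Inclusive scan along axis 0 in ceil(log2(N)) combine steps.

    Hypothetical helper for illustration. `fn(a, b)` must be associative
    and must operate elementwise over the leading axis, with `a` holding
    the earlier elements and `b` the later ones.
    """
    n = elems.shape[0]
    shift = 1
    # Shapes are static, so this Python loop unrolls under jax.jit.
    while shift < n:
        # Element i >= shift is combined with element i - shift.
        combined = fn(elems[:-shift], elems[shift:])
        # The first `shift` elements are already final and pass through.
        elems = jnp.concatenate([elems[:shift], combined], axis=0)
        shift *= 2
    return elems


# Usage: cumulative matmul over a stack of small matrices.
key = jax.random.PRNGKey(0)
mats = jnp.eye(3) + 0.1 * jax.random.normal(key, (7, 3, 3))
cumulative = jax.jit(lambda m: hillis_steele_scan(jnp.matmul, m))(mats)
reference = lax.associative_scan(jnp.matmul, mats)
```

The trade-off is that this variant applies `fn` on the order of $N\log_2{N}$ times in total, versus on the order of $N$ for the work-efficient scan, so it is only likely to help when the device has spare parallel capacity. Benchmarking against `lax.associative_scan` on the actual problem size is the only way to know.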