Replies: 1 comment · 1 reply
Hey, did you manage to make it work?
I'm interested in implementing sliding window attention, which roughly reduces to taking two tensors `Q: [B, L, D]` and `K: [B, L, D]`, performing a (padded) sliding window operation on K to obtain `WK: [B, L, W, D]`, and then dotting WK and Q along D, obtaining `A: [B, L, W]`. (That is, we want to take sliding windows of K and, for each position in L, dot Q against the appropriate window.)
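For concreteness, here is a minimal sketch of that operation in plain JAX that materialises `WK` explicitly (the left-padding convention and the `sliding_window_scores` name are just illustrative choices, not part of any existing API):

```python
import jax.numpy as jnp

def sliding_window_scores(q, k, window):
    """q, k: [B, L, D] -> scores A: [B, L, W], naively materialising WK."""
    B, L, D = q.shape
    # Pad K on the left along L so every position has a full window.
    k_padded = jnp.pad(k, ((0, 0), (window - 1, 0), (0, 0)))    # [B, L+W-1, D]
    # Gather the window ending at each position: WK is [B, L, W, D].
    idx = jnp.arange(L)[:, None] + jnp.arange(window)[None, :]  # [L, W]
    wk = k_padded[:, idx, :]                                    # [B, L, W, D]
    # Dot each query against its window along D: A is [B, L, W].
    return jnp.einsum('bld,blwd->blw', q, wk)
```

This spells out the `[B, L, W, D]` intermediate explicitly; the question is whether XLA can be convinced never to actually materialise it.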
Doing this efficiently in pure JAX really depends on having a "sufficiently smart compiler", or rather tricking the compiler into being sufficiently smart. Based on @mattjj's comment in #3171 (comment), I was hopeful the compiler would be smart enough. However, I've tried a few different ways of performing the sliding window operation, and none of them convince XLA to do the "smart" thing.
Below is a bit of a dump of attempts. There are broadly two approaches: one explicitly constructs a padded moving window and then runs an einsum like the one above; the other pushes the slice and the einsum into a single vmap over the L dimension.
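Roughly, the first approach is the explicit-window version sketched above; the second looks something like this (an illustrative sketch, not the exact code from the dump):

```python
import jax
import jax.numpy as jnp

def vmapped_slice_dot(q, k, window):
    """q, k: [B, L, D] -> A: [B, L, W], slicing K inside a vmap over L."""
    B, L, D = q.shape
    k_padded = jnp.pad(k, ((0, 0), (window - 1, 0), (0, 0)))   # [B, L+W-1, D]

    def per_position(pos):
        # Window of K ending at position `pos`: [B, W, D].
        wk = jax.lax.dynamic_slice_in_dim(k_padded, pos, window, axis=1)
        # Dot the query at `pos` against its window: [B, W].
        return jnp.einsum('bd,bwd->bw', q[:, pos, :], wk)

    # vmap over positions gives [L, B, W]; move L back next to B.
    return jnp.moveaxis(jax.vmap(per_position)(jnp.arange(L)), 0, 1)
```

The hope with this shape is that XLA fuses the dynamic slice into the contraction rather than materialising all the windows at once.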
There are two failure modes: an expected one where XLA wants to allocate a buffer of shape `[L, B, W, D]` (64 GB per device in the config I tried: L=8192, D=4096, W=512), and another where it wants to allocate a buffer of shape `[L, B, L, D]` (1 TB per device). The former makes sense to me; I don't understand why it falls into the latter.
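For reference, those numbers line up with fp32 buffers at B=1:

```python
# Back-of-the-envelope sizes, assuming fp32 (4 bytes/element) and B=1.
L, D, W = 8192, 4096, 512
print(L * W * D * 4 / 2**30)  # [L, B, W, D] buffer: 64.0 GiB
print(L * L * D * 4 / 2**40)  # [L, B, L, D] buffer: 1.0 TiB
```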
Is there something obvious I'm missing?