[Pipeliner] Implement more sophisticated partitioning strategy for attention (#6660)
This PR implements a more sophisticated partitioning and scheduling
strategy, aimed at generating a better schedule for forward attention.
This PR enables composing the pipeliner with warp specialization by
tweaking `scheduleLoops` to consider a `ttg.assigned_stage` attribute.
Normally, the pipeline scheduler determines stages from a DAG over the
latency ops, but this attribute lets a higher-level pass inject an
assigned stage directly into `scheduleKeyOps`. For example, warp
specialization can place `load K` in stage 0 and `load V` in stage 2,
and `MMA QK` in stage 0 and `MMA PV` in stage 2, in their respective
partitions.
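
As a rough illustration of the mechanism (not the actual pass code), a
higher-level pass could attach the attribute through MLIR's generic
attribute API, and the scheduler could consult it before falling back to
the DAG-derived stage. The helper names `assignStage` and
`getStageOrDefault` below are hypothetical; only the
`ttg.assigned_stage` attribute name comes from this PR.

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Hypothetical helper: pin `op` to a pipeline stage by attaching an
// integer `ttg.assigned_stage` attribute.
static void assignStage(Operation *op, int32_t stage) {
  Builder b(op->getContext());
  op->setAttr("ttg.assigned_stage", b.getI32IntegerAttr(stage));
}

// Hypothetical helper: return the assigned stage if present, otherwise
// fall back to the stage derived from the latency-op DAG.
static int32_t getStageOrDefault(Operation *op, int32_t dagDerivedStage) {
  if (auto attr = op->getAttrOfType<IntegerAttr>("ttg.assigned_stage"))
    return static_cast<int32_t>(attr.getInt());
  return dagDerivedStage;
}
```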
This PR also reworks partition assignment to generate more than one
vector partition. Previously, the partitioning strategy took the loop
body and "outlined" async loads and async MMAs into their own
partition, keeping the rest of the loop body in the default
partition.
Now, partition assignment also considers high-latency synchronous
operations, starting with `math.exp2` over a large number of elements.
It performs a first-order partition assignment, placing local users and
dependencies of loads and MMAs into the same partition, and then
clusters the remaining operations together. Each cluster is then
handled in one of four ways:
1. Assigned entirely to the source partition
2. Assigned entirely to the sink partition
3. Assigned into a wholly new partition
4. Rematerialized into sink partitions along critical paths
The idea is simply to shorten the critical paths between the latency
operations simultaneously. This strategy automatically derives the same
schedule that CUTLASS uses for FMHA (correction partition, etc.).
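
A minimal, purely conceptual sketch of picking among the four placements
above, assuming each candidate placement comes with an estimated
critical-path cost (the `Cluster`, `Placement`, and cost fields are made
up for illustration and are not the pass's actual data structures):

```cpp
#include <algorithm>
#include <array>
#include <cstdint>
#include <cstdio>

// Hypothetical placement options mirroring the four choices listed above.
enum class Placement { Source, Sink, NewPartition, Rematerialize };

// Toy model of a cluster of synchronous ops left over after the
// first-order assignment; the costs are hypothetical estimates of the
// resulting critical path between latency ops for each placement.
struct Cluster {
  std::array<int64_t, 4> criticalPathCost; // indexed by Placement
};

// Pick the placement that minimizes the critical path between latency ops.
Placement choosePlacement(const Cluster &c) {
  auto it = std::min_element(c.criticalPathCost.begin(),
                             c.criticalPathCost.end());
  return static_cast<Placement>(it - c.criticalPathCost.begin());
}

int main() {
  Cluster softmaxCluster{{120, 80, 95, 60}};
  std::printf("chosen placement = %d\n",
              static_cast<int>(choosePlacement(softmaxCluster)));
}
```

In practice the per-placement costs would come from a cost model such as
the one discussed below; the sketch only illustrates the selection rule
"pick whichever placement minimizes the critical path between latency
ops."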
There is currently no cost model for deciding between rematerialization
and sending intermediates over shared memory (i.e. placing them in their
own partition), but we can probably reuse the one @apgoucher implemented
for `remove-layout-conversions` in
triton-lang/triton#6667.
This PR also slightly tweaks the kernel code in `06-fused-attention.py`
to reduce register pressure. With these changes, it achieves close to
700 TFLOPS at DHEAD=64 and around 960-1080 TFLOPS at DHEAD=128.
TODO:
- [ ] Write lit tests for new scheduler
- [ ] Write integration tests for FMHA
---------
Co-authored-by: Chris Sullivan <[email protected]>