Commit 26ce12c
Swap FFA backward QK loop (#204)
* added vis_cute_layout util func
* made minor updates to vis_cute_layout util func
* added debug print code for fwd
* added debug print code for bwd
* fixed fwd debug print
* added swap_bwd_qk_loop to ffa args; adjusted the order of args; updated the docstring and the usage of ffa fwd/bwd funcs
* added SwapBwdQKLoop template param to ffa bwd; passed swap_bwd_qk_loop flag through ffa jit system
* set up the framework for ffa_bwd loop k
* moved static_switch.h and utils.h to ffa sub-dir
* added canonical_warp_idx_in_warpgroup sync/nosync to utils.h
* added canonical_thread_idx_in_warpgroup nosync/sync utils func
* added sync_cga_threads to utils; set up shared storage and pipeline init
* initialized the bwd schedule func; added BwdNamedBarriersLoopK
* renamed funcs for loop-q
* partially implemented load_with_loop_k; added get_tma_multi_cast_meta, sizeof_bytes_v utils funcs
* simplified the usage of TileShape_MNK
* simplified the usage of cutlass::gemm::collective::detail
* initialized the smem layout for swap_qk_loop; did extensive polishing
* initialized the params for swap_qk_loop
* removed runtime debug code
* lightly polished load func for fwd/bwd
* refactored load func for fwd/bwd
* implemented the load func (i.e. producer) for swap_qk_loop
* added detailed comments to scheduling code for fwd/bwd; initialized the consumer scheduling code for swap_qk_loop
* removed all the leftover debug code
* lightly polished bwd mma func
* added dKV acc smem layout and tensor storage
* updated bwd args when switching swap qk loop; unified /* DEBUG */ tag
* added some softmax helper funcs; implemented mma func up to P, dS
* polished the comments of bwd mma func again
* implemented dv gemm
* added detailed comments to the second half of bwd mma when not swap_qk_loop
* finished mma func for swap_qk_loop w/o Slice_dQKV_Mma
* implemented mma func when swap_qk_loop with Slice_dQKV_Mma
* removed temp debug signature
* implemented store_kv func
* implemented epilogue store funcs
* made auto_range_merge a jit template parameter
* fixed some minor compilation errors
* found the named barriers issue and added static assertion
* implemented barrier manager
* added barrier traits and extended named barrier sync/arrive with raw barrier IDs; deleted flash::named_barrier_xxx APIs
* fixed another out-of-smem-limit issue by changing {64, 128, 128} to {64, 64, 128}
* fixed kwargs of Seqlenk_mask (minor)
* fixed scheduler_args from k ranges to q ranges when swap_qk_loop
* fixed a minor typo
* fixed layout idx (minor) but left a fixme for when not using tma
* added temp debug code to align the tile size with before, reducing shared memory usage by ignoring dk
* fixed the store dkv bug (bidh -> bidh_kv)
* updated temp debug code to align the tile size with before, reducing shared memory usage by ignoring dv, instead of ignoring dk
* fixed a hang in get_lse_scaled when not all lanes enter it
* removed softcap=True from prebuild to shorten compilation time
* added temp debug code for test ffa
* removed all the debug code and adjusted the atom layouts
* fixed the missing SwapBwdQKLoop template param for tile_size_bwd_sm90 in prepare_mha_bwd
* removed debug print code
* removed all remaining debug code and added swap_bwd_qk_loop to test_ffa
* fixed disable_bwd_dkv_atomic_reduction so that it is only enabled with MHA
* optimized check_mask_lse to run only for the last m block job of each batch
* lightly polished tile scheduler and added count_in_warp utils func
* added swap_bwd_qk_loop dense benchmark
* added merge csv utils script
* refactored store dq,dk,dv
* made the producer storer use 2 warps; added some static assertions for swap qk loop
* renamed atom layout
* added one comment for NumProducerThreads
* updated merge_csv utils script
* renamed the file to fix a typo in its name
* merged main as one single commit
* fixed a missing arg
* adjusted the ffa arg order
* minor polished ffa fwd code
* minor fixes
* updated prebuild logic in setup.py

1 parent a77026d commit 26ce12c
File tree
37 files changed (+4085, -1384 lines changed)
- exps
- attn
- dist_attn
- baselines
- magi_attention
- csrc/flexible_flash_attention
- functional
- meta/collection
- utils
- tests
- test_attn